tel: +48 728 438 076
email: piotr.hyzy@eviden.com
Start¶
Round of introductions - microphone check¶
A few rules¶
- In case of problems => chat, then SMS and phone, NOT email
- Training materials
- All questions are OK
- The Vegas rule
- Headphones
- Cameras
- Chat
- Announce any absences at the start of each day (emergencies too), everything on the chat
- Muted by default during the lecture
- Breaks (working hours 08:30 - 16:30)
  - block 09:00 - 10:00
  - coffee break 10:00 - 10:15 (15')
  - block 10:15 - 11:45
  - coffee break 11:45 - 12:00
  - block 12:00 - 13:30
  - lunch 13:30 - 14:00
  - block 14:00 - 15:00
  - coffee break 15:00 - 15:10
  - block 15:10 - 16:00
  - all times are plus/minus 10'
- How to ask a question? 1) interrupt 2) ask on the chat 3) raise a virtual hand
- IDE => any
- Each exercise in a separate file/notebook
- Do not invite other people
- We start on time
- Exercises in pairs, rotations, ask for help
Pytest¶
pytest Fundamentals¶
Pytest is a powerful and easy-to-use testing framework for Python. This module will introduce the basics of using Pytest, including setting up test cases, running them, and interpreting the results. By the end of this section, you'll understand the foundational concepts of Pytest and be able to write your first test cases.
! pip install pytest
Requirement already satisfied: pytest in ./.venv/lib/python3.12/site-packages (8.2.2)
Requirement already satisfied: iniconfig in ./.venv/lib/python3.12/site-packages (from pytest) (2.0.0)
Requirement already satisfied: packaging in ./.venv/lib/python3.12/site-packages (from pytest) (24.1)
Requirement already satisfied: pluggy<2.0,>=1.5 in ./.venv/lib/python3.12/site-packages (from pytest) (1.5.0)
System Under Test (SUT): factorial(num)¶
The System Under Test (SUT) refers to the specific function or module being tested. In this case, it's the factorial(num) function, which calculates the factorial of a number.
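# %%writefile mathutils.py
def factorial(num: int) -> int:
    if not isinstance(num, int):
        raise TypeError('Argument must be int')
    if num < 0:
        # without this guard, negative input would recurse until RecursionError
        raise ValueError('Argument must be non-negative')
    if num == 0:
        return 1
    else:
        return factorial(num - 1) * num
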
# n! = (n-1)! * n
# 3! = 2! * 3 = 1! * 2 * 3 = 0! * 1 * 2 * 3 = 1 * 1 * 2 * 3 = 6
factorial(0)
1
factorial(1)
1
factorial(3)
6
Understanding the SUT is critical because test cases are built to validate its behavior against expected outcomes.
Example Project Structure¶
Organizing your project correctly is essential for writing maintainable and scalable tests. A typical Pytest-compatible project structure might look like this:
mathutils/
    __init__.py              # Package initialization
    factorial_utils.py       # Factorial-related utilities
    optimization.py          # Optimization-related code
    domain/
        __init__.py          # Subpackage initialization (another module)
        vector.py            # Vector model
        base.py              # Base code
tests/
    __init__.py              # Package initialization for tests
    test_factorial_utils.py  # Tests for factorial utilities
    test_optimization.py     # Tests for optimization
    requirements.txt         # Test dependency management
main.py                      # Main app file
setup.py                     # Project metadata and installation setup
requirements.txt             # Dependency management
Key Points:
- Code files are stored in the mathutils/ directory.
- Test files are in the tests/ directory and follow the test_*.py naming convention.
- setup.py and requirements.txt help manage project dependencies and installation.
Writing Assertions¶
An assertion is a condition that must evaluate to True for the test to pass:
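cond = True
assert cond, 'cond must be True'  # the message is used only when cond is falsy

# This is equivalent to (__debug__ is False, and asserts are skipped, when Python runs with -O):
if __debug__:
    if not cond:
        raise AssertionError('cond must be True')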
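A minimal sketch of when the message is used: the hypothetical helper check() below runs an assert inside a try block and reports the outcome as text. The message after the comma ends up in the raised AssertionError and only matters when the condition is false. The outputs in the comments assume Python is run normally, without -O.

```python
def check(cond: bool) -> str:
    """Hypothetical helper: run an assert and report the outcome as text."""
    try:
        assert cond, 'cond must be True'
    except AssertionError as exc:
        # str(exc) is the message written after the comma in the assert statement
        return f'AssertionError: {exc}'
    return 'passed'


print(check(True))   # passed
print(check(False))  # AssertionError: cond must be True
```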
Basic Tests¶
Tests in Pytest are written as functions. Assertions are used to verify that the code behaves as expected.
%%writefile pytest_fundamentals.py
import pytest
from mathutils import factorial

def test_factorial_of_one():
    assert factorial(1) == 1

def test_factorial_of_three():
    got = factorial(3)
    expected = 7 # intentionally wrong value to show how test failing, the right value is 6
    assert expected == got, 'my error message'

def test_raises_typeerror_for_float():
    with pytest.raises(TypeError):
        factorial(3.5)  # Expecting a TypeError for non-integer input
Overwriting pytest_fundamentals.py
Explanation:
- test_factorial_of_one: Verifies the factorial of 1.
- test_factorial_of_three: Contains a deliberate failure (expected is incorrect) to demonstrate how a failing test is reported.
- test_raises_typeerror_for_float: Validates that the function raises TypeError for non-integer input.
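pytest.raises can also check the exception message through its match argument (a regular expression searched in the string form of the exception), and it can expose the caught exception for further assertions. The sketch below repeats a simplified copy of factorial() so it runs on its own; the test names are illustrative.

```python
import pytest


def factorial(num: int) -> int:
    # simplified copy of the SUT's type check and recursion, for self-containment
    if not isinstance(num, int):
        raise TypeError('Argument must be int')
    return 1 if num == 0 else factorial(num - 1) * num


def test_raises_typeerror_with_message_for_float():
    # match= is searched (re.search) in str(exception)
    with pytest.raises(TypeError, match='must be int'):
        factorial(3.5)


def test_exception_details_are_available():
    with pytest.raises(TypeError) as excinfo:
        factorial('3')
    # excinfo.value is the exception instance that was raised
    assert str(excinfo.value) == 'Argument must be int'
```

Checking the message guards against the test passing because of an unrelated TypeError raised somewhere else in the code.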
Launching Tests¶
Pytest provides a simple command-line interface to execute tests. To run tests in a specific file:
! pytest pytest_fundamentals.py
============================= test session starts ==============================
platform darwin -- Python 3.12.6, pytest-8.2.2, pluggy-1.5.0
rootdir: /Users/a563420/python_training/testing
plugins: cov-6.0.0, anyio-4.4.0
collected 3 items

pytest_fundamentals.py .F.                                               [100%]

=================================== FAILURES ===================================
___________________________ test_factorial_of_three ____________________________

    def test_factorial_of_three():
        got = factorial(3)
        expected = 7 # intentionally wrong value to show how test failing, the right value is 6
>       assert expected == got, 'my error message'
E       AssertionError: my error message
E       assert 7 == 6

pytest_fundamentals.py:11: AssertionError
=========================== short test summary info ============================
FAILED pytest_fundamentals.py::test_factorial_of_three - AssertionError: my error message
========================= 1 failed, 2 passed in 0.57s ==========================
Launching Tests with -q¶
The -q (quiet) flag provides more concise output:
! pytest pytest_fundamentals.py -q
.F.                                                                      [100%]
=================================== FAILURES ===================================
___________________________ test_factorial_of_three ____________________________

    def test_factorial_of_three():
        got = factorial(3)
        expected = 7 # intentionally wrong value to show how test failing, the right value is 6
>       assert expected == got, 'my error message'
E       AssertionError: my error message
E       assert 7 == 6

pytest_fundamentals.py:11: AssertionError
=========================== short test summary info ============================
FAILED pytest_fundamentals.py::test_factorial_of_three - AssertionError: my error message
1 failed, 2 passed in 0.46s
Useful Switches¶
Pytest includes several useful command-line options to enhance debugging:
Stop after the first failure (-x):
This stops the test run as soon as a failure is encountered.
! pytest pytest_fundamentals.py -x
============================= test session starts ==============================
platform darwin -- Python 3.12.6, pytest-8.2.2, pluggy-1.5.0
rootdir: /Users/a563420/python_training/testing
plugins: cov-6.0.0, anyio-4.4.0
collected 3 items

pytest_fundamentals.py .F

=================================== FAILURES ===================================
___________________________ test_factorial_of_three ____________________________

    def test_factorial_of_three():
        got = factorial(3)
        expected = 7 # intentionally wrong value to show how test failing, the right value is 6
>       assert expected == got, 'my error message'
E       AssertionError: my error message
E       assert 7 == 6

pytest_fundamentals.py:11: AssertionError
=========================== short test summary info ============================
FAILED pytest_fundamentals.py::test_factorial_of_three - AssertionError: my error message
!!!!!!!!!!!!!!!!!!!!!!!!!! stopping after 1 failures !!!!!!!!!!!!!!!!!!!!!!!!!!!
========================= 1 failed, 1 passed in 0.35s ==========================
Print local variables for failing tests (-l, long form --showlocals):
This displays the values of the test's local variables in the failure report to help debugging.
! pytest pytest_fundamentals.py -l
============================= test session starts ==============================
platform darwin -- Python 3.12.6, pytest-8.2.2, pluggy-1.5.0
rootdir: /Users/a563420/python_training/testing
plugins: cov-6.0.0, anyio-4.4.0
collected 3 items

pytest_fundamentals.py .F.                                               [100%]

=================================== FAILURES ===================================
___________________________ test_factorial_of_three ____________________________

    def test_factorial_of_three():
        got = factorial(3)
        expected = 7 # intentionally wrong value to show how test failing, the right value is 6
>       assert expected == got, 'my error message'
E       AssertionError: my error message
E       assert 7 == 6

expected = 7
got = 6

pytest_fundamentals.py:11: AssertionError
=========================== short test summary info ============================
FAILED pytest_fundamentals.py::test_factorial_of_three - AssertionError: my error message
========================= 1 failed, 2 passed in 0.35s ==========================
Autodiscover All Tests¶
Pytest can automatically discover and execute all tests in a project. By default, test files must be named test_*.py or *_test.py.
To autodiscover and run tests:
! pytest
============================= test session starts ==============================
platform darwin -- Python 3.12.6, pytest-8.2.2, pluggy-1.5.0
rootdir: /Users/a563420/python_training/testing
plugins: cov-6.0.0, anyio-4.4.0
collected 0 items

============================ no tests ran in 0.02s =============================
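Tips:
- Ensure all test files follow the test_*.py naming pattern. The example files written so far in this notebook are named pytest_*.py, which is why the run above collected 0 items.
- Use directories like tests/ to organize your test files.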
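pytest decides which files to collect by matching file names against its python_files setting, whose default patterns are test_*.py and *_test.py. The hypothetical helper below approximates only that file-name check with fnmatch; real collection also depends on things like rootdir, testpaths and ignored directories, so treat it as an illustration rather than pytest's actual logic.

```python
from fnmatch import fnmatch

# pytest's default `python_files` patterns; a project can change them in its config
DEFAULT_PYTHON_FILES = ('test_*.py', '*_test.py')


def matches_default_patterns(filename: str) -> bool:
    """Hypothetical helper: approximates only pytest's file-name check."""
    return any(fnmatch(filename, pattern) for pattern in DEFAULT_PYTHON_FILES)


for name in ['pytest_fundamentals.py', 'test_capsys.py', 'div_test.py', 'conftest.py']:
    print(f'{name:25} {matches_default_patterns(name)}')
```

Under these assumptions only test_capsys.py and div_test.py match; pytest_fundamentals.py does not, because it neither starts with test_ nor ends with _test.py.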
Exercise: 🏠 Dividing¶
Write as many tests as possible for the function below.
def div(a, b):
    return a / b
Solution¶
%%writefile pytest_div.py
"""
### Naming Test Functions
When creating test functions in Pytest, it's a good practice to name the function in a way that clearly communicates:
1. The **action** being tested (e.g., a function call, operation, or scenario).
2. The **expected output** or result of that action.

#### Naming Convention Example
- **Test function name**: `test_dividing_two_integers_should_give_a_float`
  - **Action**: "dividing two integers" (describes the operation being tested)
  - **Expected Output**: "should give a float" (clarifies the expected result)

This naming approach improves readability, making it easier for others to understand what the test is verifying.

#### Test Name Structure
The typical structure for a test function name is:
`test_<action>_<expected_output>`

#### Examples:
- `test_dividing_4_by_2_gives_2`: Testing that dividing 4 by 2 results in 2.
- `test_dividing_two_integers_should_give_a_float`: Testing that dividing two integers always returns a float.
- `test_raises_error_on_division_by_zero`: Testing that dividing by zero raises an error.
"""
import math

import pytest

# Define a simple division function to be tested
def div(a, b):
    return a / b

# Test that dividing 4 by 2 returns 2
def test_dividing_4_by_2_gives_2():
    assert div(4, 2) == 2

# Test that dividing two integers results in a float
def test_dividing_two_integers_should_give_a_float():
    # given / arrange - set up any necessary context (nothing needed here)

    # when / act - perform the operation
    got = div(3, 2)

    # then / assert - check that the result matches the expectation
    expected = 1.5
    assert got == expected

# Test that dividing two integers always returns a float, even if the result is an integer
def test_dividing_two_integers_returns_a_float_even_when_result_is_integer():
    assert isinstance(div(4, 2), float)

# Test that dividing by zero raises a ZeroDivisionError
def test_raises_error_on_division_by_zero():
    with pytest.raises(ZeroDivisionError):
        div(2, 0)

# Test that dividing two negative numbers gives a positive result
def test_dividing_negative_numbers():
    assert div(-3, -2) == 1.5

# Test that dividing infinity by infinity results in NaN (not a number)
def test_dividing_infinities_gives_nan():
    # Check using math.isnan since NaN is not equal to anything, including itself
    assert math.isnan(div(float('inf'), float('inf')))

# Test that dividing infinity by a finite number still results in infinity
def test_dividing_infinity_gives_infinity():
    inf = float('inf')
    assert div(inf, 2) == inf

# Test that dividing infinity by zero raises a ZeroDivisionError
def test_dividing_infinity_by_zero_raises_an_error():
    inf = float('inf')
    with pytest.raises(ZeroDivisionError):
        div(inf, 0)

# Test that dividing a boolean value works, with True treated as 1 and False as 0
def test_dividing_boolean_works():
    assert div(True, 2) == 0.5

# Test that dividing the smallest positive (subnormal) float, 5e-324, by 2 underflows to 0.0
# (note: this is not sys.float_info.epsilon, which is about 2.2e-16)
def test_dividing_smallest_float_by_two_gives_zero():
    assert div(5e-324, 2) == 0.0

# Test that dividing a huge number by a small number overflows to infinity
def test_dividing_huge_numbers_results_in_infinity():
    assert div(1e308, 0.5) == float('inf')

# Test that attempting to divide two lists raises a TypeError
def test_dividing_two_lists_raises_an_error():
    with pytest.raises(TypeError):
        div([3, 2, 5], [1, 2, 3])

# Test that the division operator can be overridden in custom classes
def test_you_can_override_division_operator():
    class Dividable:
        def __truediv__(self, other):
            return 42  # Custom division logic

    d = Dividable()
    assert div(d, d) == 42  # Confirm overridden behavior

# Test that dividing two instances of a custom object without division logic raises a TypeError
def test_dividing_custom_objects_raises_TypeError():
    class MyClass:
        pass

    m = MyClass()
    with pytest.raises(TypeError):
        div(m, m)
Overwriting pytest_div.py
! pytest pytest_div.py
============================= test session starts ==============================
platform darwin -- Python 3.12.6, pytest-8.2.2, pluggy-1.5.0
rootdir: /Users/a563420/python_training/testing
plugins: cov-6.0.0, anyio-4.4.0
collected 14 items

pytest_div.py ..............                                             [100%]

============================== 14 passed in 0.17s ==============================
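Several of the tests above differ only in their inputs and expected result. As an optional sketch that goes slightly beyond the exercise, pytest's @pytest.mark.parametrize can generate one test per row of data; pytest then reports each row as a separate test. The test name and the chosen rows are illustrative.

```python
import pytest


def div(a, b):
    return a / b


@pytest.mark.parametrize(
    ('a', 'b', 'expected'),
    [
        (4, 2, 2.0),     # exact division still gives a float
        (3, 2, 1.5),     # two integers give a float
        (-3, -2, 1.5),   # two negatives give a positive result
        (True, 2, 0.5),  # bool is a subclass of int
    ],
)
def test_div_returns_expected_float(a, b, expected):
    got = div(a, b)
    assert isinstance(got, float)
    assert got == expected
```

Keeping separate, descriptively named tests is still preferable for cases with different behavior, such as the exception tests.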
Exercise: 🏠 @suppress_output¶
Implementation:¶
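# %%writefile suppress_output_module.py
from contextlib import contextmanager
import sys

@contextmanager
def suppress_output():
    stdout = sys.stdout
    sys.stdout = None  # print() silently does nothing while sys.stdout is None
    try:
        yield stdout  # hand the original stream to the caller (bound to s in the example)
    finally:
        sys.stdout = stdout  # always restore, even if the with block raised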
Example Usage¶
print("A")
with suppress_output() as s:
    print("B")
    s.write("asdf\n")
print("C")
Output:
A
asdf
C
What to Test?¶
When testing the functionality involving sys.stdout and error handling, consider the following scenarios (points B and C refer to the print("B") and print("C") lines in the example usage):
- Changing sys.stdout in point B
  - Verify that sys.stdout is correctly changed within the scope of B.
- Error handling in point B
  - Ensure that if an error is raised in B, the suppress_output mechanism does not suppress or alter the exception.
- Restoring sys.stdout in point C
  - Confirm that sys.stdout is restored to its original value within the scope of C.
- Restoring sys.stdout in point C even in case of an error in B
  - Test that sys.stdout is restored in C even if an error is raised during execution of B.
- Ensuring s matches the original sys.stdout
  - Validate that the variable s holds a reference to the original sys.stdout.
Solution¶
%%writefile pytest_suppressoutput.py
import sys

import pytest

from suppress_output_module import suppress_output

def test_stdout_should_change_inside_suppress_output():
    before = sys.stdout  # Save the original sys.stdout for comparison
    with suppress_output():
        # assert sys.stdout is None  # This may give a false positive (FP) because the assertion is too strong.
        assert sys.stdout is not before  # This may give a false negative (FN) because the assertion is too weak

def test_stdout_should_be_restored_after_suppress_output():
    before = sys.stdout
    with suppress_output():
        pass
    assert sys.stdout is before

def test_suppress_output_propagates_exceptions():
    class ExampleException(Exception):
        pass

    with pytest.raises(ExampleException):
        with suppress_output():
            raise ExampleException

def test_suppress_output_restores_original_stdout_even_in_case_of_an_exception():
    class ExampleException(Exception):
        pass

    before = sys.stdout
    try:
        with suppress_output():
            raise ExampleException
    except ExampleException:
        pass
    assert sys.stdout is before

def test_suppress_output_returns_original_stdout():
    before = sys.stdout
    with suppress_output() as s:
        assert s is before
Here's a detailed explanation of the comments in the test function:
Test Function Code¶
def test_stdout_should_change_inside_suppress_output():
    before = sys.stdout  # Save the original sys.stdout for comparison
    with suppress_output():
        # assert sys.stdout is None  # This may give a false positive (FP) because the assertion is too strong.
        assert sys.stdout is not before  # This may give a false negative (FN) because the assertion is too weak
Explanation of Comments¶
Comment 1:¶
# assert sys.stdout is None  # This may give a false positive (FP) because the assertion is too strong.
- Reason: The code inside suppress_output() sets sys.stdout to None. However, directly asserting that sys.stdout is None may not always be reliable in broader contexts:
  - If the implementation changes slightly (e.g., sys.stdout is replaced with a dummy io.StringIO object instead of None), this assertion would fail, even though the behavior of suppressing output remains correct.
  - This makes the assertion too strict and prone to failing unnecessarily in valid cases.
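To make the "too strict" point concrete, consider the hypothetical variant below, which discards output into an io.StringIO buffer instead of setting sys.stdout to None. Output is still suppressed, yet a check for None would fail. The name suppress_output_to_buffer is an illustration, not part of the exercise.

```python
import io
import sys
from contextlib import contextmanager


@contextmanager
def suppress_output_to_buffer():
    """Hypothetical variant of suppress_output: uses an in-memory buffer instead of None."""
    original = sys.stdout
    sys.stdout = io.StringIO()  # output still disappears, just into a throwaway buffer
    try:
        yield original
    finally:
        sys.stdout = original


with suppress_output_to_buffer():
    print('hidden')             # goes into the buffer, never reaches the terminal
    strict_ok = sys.stdout is None
print(strict_ok)                # False: the too-strict assertion would fail for this valid variant
```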
Comment 2:¶
# assert sys.stdout is not before  # This may give a false negative (FN) because the assertion is too weak
- Reason: This assertion checks only that sys.stdout has changed from its original state (before). While this ensures sys.stdout is modified, it doesn't confirm how it has been changed or whether it matches the intended behavior of the suppress_output() function:
  - If sys.stdout is set to a different value (e.g., a mock object), this assertion would still pass, even if the behavior of suppressing output is incorrect.
  - This makes the assertion too lenient and may miss issues (false negatives).
Suggested Improvement¶
To balance the strictness and flexibility of the assertions, you could test for specific behavior rather than relying on the exact value of sys.stdout:
import io
import sys

from suppress_output_module import suppress_output

def test_stdout_should_change_inside_suppress_output():
    before = sys.stdout
    with suppress_output() as captured_stdout:
        assert sys.stdout is not before  # Ensure sys.stdout changes
        assert sys.stdout is None or isinstance(sys.stdout, io.StringIO)  # Allow None or an in-memory buffer
        assert captured_stdout is before  # Confirm the original stdout is captured correctly
Key Points:¶
- Testing specific behavior (e.g., that sys.stdout changes and that the original is captured) is often more robust than asserting exact values.
- The balance between strict and lenient assertions is critical for avoiding both false positives and false negatives.
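One way to test behavior rather than exact values is sketched below: temporarily point sys.stdout at an io.StringIO with contextlib.redirect_stdout, then check what actually reached it. The implementation from suppress_output_module is repeated so the snippet runs on its own; the test name is illustrative.

```python
import io
import sys
from contextlib import contextmanager, redirect_stdout


@contextmanager
def suppress_output():
    # same implementation as suppress_output_module, repeated for self-containment
    stdout = sys.stdout
    sys.stdout = None
    try:
        yield stdout
    finally:
        sys.stdout = stdout


def test_nothing_printed_inside_suppress_output_reaches_stdout():
    outer = io.StringIO()
    with redirect_stdout(outer):       # make the "original" stdout inspectable
        with suppress_output() as s:
            print('hidden')            # print() does nothing while sys.stdout is None
            s.write('via s\n')         # s is the original stream (here: outer)
        print('visible')               # stdout is restored, so this reaches outer
    assert outer.getvalue() == 'via s\nvisible\n'
```

This test would keep passing if the implementation switched to an in-memory buffer, and it would fail if printed text leaked through.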
Fixtures¶
Fixtures are a key feature in Pytest that allow you to set up and tear down resources required for your tests. They provide a way to share setup code across multiple tests, making tests more efficient and maintainable. Unlike the traditional setUp and tearDown methods in unit testing frameworks like unittest, fixtures in Pytest are more flexible and reusable.
What Are Fixtures?¶
- Setup Code: Fixtures are functions that prepare some state or resources required for your tests.
- Teardown Code: Cleanup code in a fixture (for example, the code after a yield) is run by pytest after the test, even if the test fails.
- Reusability: Fixtures can be shared across multiple test functions, classes, or modules.
- Scope: Fixtures can have different scopes (function, class, module, package, session), determining their lifespan.
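Teardown is commonly expressed with a yield fixture: code before yield is setup, the yielded value is what the test receives, and code after yield is teardown. The sketch below uses a hypothetical workdir fixture and drives the generator by hand to make the order visible; the manual next() calls only approximate what pytest does internally.

```python
import shutil
import tempfile
from pathlib import Path


# In a test file this function would be decorated with @pytest.fixture and
# requested by name, e.g. `def test_x(workdir): ...`. It is left undecorated
# here so it can be driven by hand below.
def workdir():
    path = Path(tempfile.mkdtemp())   # setup: runs before the test
    try:
        yield path                    # the yielded value is what the test receives
    finally:
        shutil.rmtree(path)           # teardown: runs after the test, even if it failed


# Simplified sketch of how a runner drives a yield fixture (not pytest's actual code):
gen = workdir()
path = next(gen)                      # run the setup part and get the value
(path / 'data.txt').write_text('hello')
print(path.exists())                  # True
next(gen, None)                       # resume after `yield`: the teardown runs
print(path.exists())                  # False
```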
Unique Temporary Directory¶
Pytest provides a built-in fixture called tmpdir that creates a unique temporary directory for each test function. This is useful for testing file-related operations.
%%writefile pytest_tmpdir.py
import pytest

def test_needsfiles(tmpdir):
    print(tmpdir)
    print(type(tmpdir))
    assert False
Overwriting pytest_tmpdir.py
! pytest pytest_tmpdir.py -q
F                                                                        [100%]
=================================== FAILURES ===================================
_______________________________ test_needsfiles ________________________________

tmpdir = local('/private/var/folders/w9/9hmtpfzj64v0x0j841ny9y280000gn/T/pytest-of-a563420/pytest-5/test_needsfiles0')

    def test_needsfiles(tmpdir):
        print(tmpdir)
        print(type(tmpdir))
>       assert False
E       assert False

pytest_tmpdir.py:6: AssertionError
----------------------------- Captured stdout call -----------------------------
/private/var/folders/w9/9hmtpfzj64v0x0j841ny9y280000gn/T/pytest-of-a563420/pytest-5/test_needsfiles0
<class '_pytest._py.path.LocalPath'>
=========================== short test summary info ============================
FAILED pytest_tmpdir.py::test_needsfiles - assert False
1 failed in 0.42s
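Output:
- The tmpdir fixture provides the temporary directory as a legacy py.path.local object (reported as _pytest._py.path.LocalPath in the output above).
- The directory is unique for each test function invocation. It is not deleted right after the test: pytest keeps the base temporary directories of the most recent runs (by default the last 3) and removes older ones.
Key Points:
- tmpdir is isolated for each test, preventing side effects between tests.
- It simplifies testing code that interacts with the file system.
- For new code, the pytest documentation recommends tmp_path over the legacy tmpdir fixture.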
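tmp_path is the pathlib-based counterpart of tmpdir: it gives each test a fresh directory as a standard pathlib.Path. A sketch of a test using it follows; save_report is a hypothetical function under test. Because a test function is an ordinary function, it can also be called outside pytest with any empty directory, which is handy for quick experiments.

```python
import tempfile
from pathlib import Path


def save_report(directory: Path, text: str) -> Path:
    """Hypothetical code under test: writes a report file and returns its path."""
    target = directory / 'report.txt'
    target.write_text(text)
    return target


def test_save_report_writes_file(tmp_path: Path):
    # under pytest, tmp_path is a fresh, empty pathlib.Path directory for this test
    target = save_report(tmp_path, 'all good')
    assert target.read_text() == 'all good'
    assert list(tmp_path.iterdir()) == [target]


# Outside pytest, the test can be called by hand with any empty directory:
if __name__ == '__main__':
    with tempfile.TemporaryDirectory() as d:
        test_save_report_writes_file(Path(d))
        print('ok')
```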
Listing All Fixtures¶
Pytest allows you to view all available fixtures using the --fixtures option.
! pytest --fixtures
============================= test session starts ==============================
platform darwin -- Python 3.12.6, pytest-8.2.2, pluggy-1.5.0
rootdir: /Users/a563420/python_training/testing
plugins: cov-6.0.0, anyio-4.4.0
collected 0 items
cache -- .venv/lib/python3.12/site-packages/_pytest/cacheprovider.py:560
    Return a cache object that can persist state between testing sessions.
capsysbinary -- .venv/lib/python3.12/site-packages/_pytest/capture.py:1003
    Enable bytes capturing of writes to ``sys.stdout`` and ``sys.stderr``.
capfd -- .venv/lib/python3.12/site-packages/_pytest/capture.py:1030
    Enable text capturing of writes to file descriptors ``1`` and ``2``.
capfdbinary -- .venv/lib/python3.12/site-packages/_pytest/capture.py:1057
    Enable bytes capturing of writes to file descriptors ``1`` and ``2``.
capsys -- .venv/lib/python3.12/site-packages/_pytest/capture.py:976
    Enable text capturing of writes to ``sys.stdout`` and ``sys.stderr``.
doctest_namespace [session scope] -- .venv/lib/python3.12/site-packages/_pytest/doctest.py:738
    Fixture that returns a :py:class:`dict` that will be injected into the namespace of doctests.
pytestconfig [session scope] -- .venv/lib/python3.12/site-packages/_pytest/fixtures.py:1338
    Session-scoped fixture that returns the session's :class:`pytest.Config` object.
record_property -- .venv/lib/python3.12/site-packages/_pytest/junitxml.py:284
    Add extra properties to the calling test.
record_xml_attribute -- .venv/lib/python3.12/site-packages/_pytest/junitxml.py:307
    Add extra xml attributes to the tag for the calling test.
record_testsuite_property [session scope] -- .venv/lib/python3.12/site-packages/_pytest/junitxml.py:345
    Record a new ``<property>`` tag as child of the root ``<testsuite>``.
tmpdir_factory [session scope] -- .venv/lib/python3.12/site-packages/_pytest/legacypath.py:303
    Return a :class:`pytest.TempdirFactory` instance for the test session.
tmpdir -- .venv/lib/python3.12/site-packages/_pytest/legacypath.py:310
    Return a temporary directory path object which is unique to each test function invocation, created as a sub directory of the base temporary directory.
caplog -- .venv/lib/python3.12/site-packages/_pytest/logging.py:602
    Access and control log capturing.
monkeypatch -- .venv/lib/python3.12/site-packages/_pytest/monkeypatch.py:33
    A convenient fixture for monkey-patching.
recwarn -- .venv/lib/python3.12/site-packages/_pytest/recwarn.py:32
    Return a :class:`WarningsRecorder` instance that records all warnings emitted by test functions.
tmp_path_factory [session scope] -- .venv/lib/python3.12/site-packages/_pytest/tmpdir.py:242
    Return a :class:`pytest.TempPathFactory` instance for the test session.
tmp_path -- .venv/lib/python3.12/site-packages/_pytest/tmpdir.py:257
    Return a temporary directory path object which is unique to each test function invocation, created as a sub directory of the base temporary directory.

------------------ fixtures defined from anyio.pytest_plugin -------------------
anyio_backend [module scope] -- .venv/lib/python3.12/site-packages/anyio/pytest_plugin.py:132
    no docstring available
anyio_backend_name -- .venv/lib/python3.12/site-packages/anyio/pytest_plugin.py:137
    no docstring available
anyio_backend_options -- .venv/lib/python3.12/site-packages/anyio/pytest_plugin.py:145
    no docstring available

------------------- fixtures defined from pytest_cov.plugin --------------------
no_cover -- .venv/lib/python3.12/site-packages/pytest_cov/plugin.py:429
    A pytest fixture to disable coverage.
cov -- .venv/lib/python3.12/site-packages/pytest_cov/plugin.py:434
    A pytest fixture to provide access to the underlying coverage object.

============================ no tests ran in 0.07s =============================
Common Built-in Fixtures:¶
- capsys: Captures output written to sys.stdout and sys.stderr.
- monkeypatch: Temporarily replaces attributes, dictionary items or environment variables for the duration of a test.
- tmpdir: Provides a unique temporary directory for each test.
- pytestconfig: Gives access to the Pytest configuration object.
Benefits:
- Helps discover useful fixtures provided by Pytest and third-party plugins.
- Saves time by reusing existing functionality.
Using capsys¶
The capsys fixture allows you to capture text written to sys.stdout and sys.stderr during a test.
%%writefile test_capsys.py
def test_using_capsys(capsys):
    print('asdf')
    out, err = capsys.readouterr()
    print('out', out)
    assert out == 'asdf\n'
Overwriting test_capsys.py
! pytest test_capsys.py
============================= test session starts ==============================
platform darwin -- Python 3.12.6, pytest-8.2.2, pluggy-1.5.0
rootdir: /Users/a563420/python_training/testing
plugins: cov-6.0.0, anyio-4.4.0
collected 1 item

test_capsys.py .                                                         [100%]

============================== 1 passed in 0.07s ===============================
Key Points:
- Use capsys.readouterr() to read the output captured so far; each call also resets the capture buffer.
- The captured stdout (out) and stderr (err) are returned separately.
- Useful for testing CLI tools or functions that print output.
Implementing Fixtures¶
You can define your own fixtures using the @pytest.fixture decorator. Fixtures encapsulate setup logic, allowing you to create reusable components for test preparation. Pytest will automatically execute fixtures before the test functions that request them.
Steps to Create Your Own Fixture¶
- Import pytest:
  - Ensure you have pytest imported in your test file.
- Define a function with the setup logic:
  - Create a function that contains the necessary setup steps for your tests.
- Annotate it with @pytest.fixture:
  - Use the @pytest.fixture decorator to mark the function as a fixture.
- Return the required object:
  - The function should return the object or resource that will be passed to the test.
%%writefile pytest_implementing_fixtures.py
from time import sleep

import pytest

@pytest.fixture
def empty_list():
    print('preparing database')
    sleep(2)  # simulate slow setup, e.g. preparing a database
    return []

def test_a(empty_list):
    print(empty_list)
    empty_list.append(2)
    print(empty_list)

def test_b(empty_list):
    print(empty_list)
    empty_list.append(2)
    print(empty_list)
Overwriting pytest_implementing_fixtures.py
! pytest pytest_implementing_fixtures.py -sv
============================= test session starts ==============================
platform darwin -- Python 3.12.6, pytest-8.2.2, pluggy-1.5.0 -- /Users/a563420/python_training/testing/.venv/bin/python
cachedir: .pytest_cache
rootdir: /Users/a563420/python_training/testing
plugins: cov-6.0.0, anyio-4.4.0
collected 2 items

pytest_implementing_fixtures.py::test_a preparing database
[]
[2]
PASSED
pytest_implementing_fixtures.py::test_b preparing database
[]
[2]
PASSED

============================== 2 passed in 4.07s ===============================
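Key Points:
- Fixtures are instantiated before the test runs.
- The return value of the fixture is passed to the test function as an argument.
- With the default function scope, the fixture runs again for every test: both tests start with an empty list, and the 2-second setup is paid twice (about 4 seconds in total).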
Sharing Fixture Instances¶
You can control how often a fixture is created by setting its scope. The scope determines the lifespan of a fixture and how many times it is invoked during a test session. Pytest provides five built-in scopes (function, class, module, package, session); the four most common ones are described below:
Fixture Scopes¶
function (default)
- Definition: A new instance of the fixture is created for each test function that uses it.
- Use Case: Ideal for tests that require an isolated, fresh setup for every test.
- Example:
%%writefile pytest_fixture_scope_function.py
import pytest

@pytest.fixture(scope='function')
def resource():
    print("Setting up resource for each test")
    return "function_resource"

def test_a(resource):
    assert resource == "function_resource"
    resource = 'something else'  # rebinds only the local name; other tests are unaffected

def test_b(resource):
    assert resource == "function_resource"
    resource = 'something else'
Overwriting pytest_fixture_scope_function.py
! pytest pytest_fixture_scope_function.py -sv
============================= test session starts ==============================
platform darwin -- Python 3.12.6, pytest-8.2.2, pluggy-1.5.0 -- /Users/a563420/python_training/testing/.venv/bin/python
cachedir: .pytest_cache
rootdir: /Users/a563420/python_training/testing
plugins: cov-6.0.0, anyio-4.4.0
collected 2 items

pytest_fixture_scope_function.py::test_a Setting up resource for each test
PASSED
pytest_fixture_scope_function.py::test_b Setting up resource for each test
PASSED

============================== 2 passed in 0.08s ===============================
class
- Definition: A single instance of the fixture is created and shared among all tests in a class.
- Use Case: Useful for tests within a class that share the same setup but need isolation from other classes.
- Example:
%%writefile pytest_fixture_scope_class.py
import pytest

@pytest.fixture(scope='class')
def resource():
    print("Setting up resource for the class")
    return 'class_resource'

class TestExample:
    def test_a(self, resource):
        print(id(resource))
        assert resource == 'class_resource'
        resource = 'dddddd'  # rebinds only the local name in this test; the shared fixture value is unchanged

    def test_b(self, resource):
        assert resource == 'class_resource'
Overwriting pytest_fixture_scope_class.py
! pytest pytest_fixture_scope_class.py -sv
============================= test session starts ==============================
platform darwin -- Python 3.12.6, pytest-8.2.2, pluggy-1.5.0 -- /Users/a563420/python_training/testing/.venv/bin/python
cachedir: .pytest_cache
rootdir: /Users/a563420/python_training/testing
plugins: cov-6.0.0, anyio-4.4.0
collected 2 items

pytest_fixture_scope_class.py::TestExample::test_a Setting up resource for the class
4365607152
PASSED
pytest_fixture_scope_class.py::TestExample::test_b PASSED

============================== 2 passed in 0.07s ===============================
%%writefile pytest_fixture_scope_class1.py
import pytest

@pytest.fixture(scope='class')
def resource():
    print("Setting up resource for the class")
    return {"class_resource": "class_resource"}

class TestExample1:
    def test_a(self, resource):
        assert resource == {"class_resource": "class_resource"}
        # Uncommenting the next line would make test_b fail: both tests share the same dict
        # resource.update(new_key="new_value")

    def test_b(self, resource):
        assert resource == {"class_resource": "class_resource"}
        resource.update(new_key="new_value")  # mutates the instance shared within TestExample1

class TestExample2:
    def test_a(self, resource):
        # a new class gets a new fixture instance, so the mutation above is not visible here
        assert resource == {"class_resource": "class_resource"}
Overwriting pytest_fixture_scope_class1.py
! pytest pytest_fixture_scope_class1.py -sv
============================= test session starts ==============================
platform darwin -- Python 3.12.6, pytest-8.2.2, pluggy-1.5.0 -- /Users/a563420/python_training/testing/.venv/bin/python
cachedir: .pytest_cache
rootdir: /Users/a563420/python_training/testing
plugins: cov-6.0.0, anyio-4.4.0
collected 3 items

pytest_fixture_scope_class1.py::TestExample1::test_a Setting up resource for the class
PASSED
pytest_fixture_scope_class1.py::TestExample1::test_b PASSED
pytest_fixture_scope_class1.py::TestExample2::test_a Setting up resource for the class
PASSED

============================== 3 passed in 0.09s ===============================
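Because a class-, module- or session-scoped fixture hands the same object to every test in its scope, a mutation such as the resource.update(...) call in test_b would be visible to any later test in TestExample1; TestExample2 passes only because it receives its own instance. One common way to keep an expensive shared resource while still isolating tests is sketched below: a function-scoped fixture requests the shared one and returns a deep copy. The fixture names and data are hypothetical.

```python
import copy

import pytest


@pytest.fixture(scope='module')
def base_config():
    # hypothetical data that is expensive to build, so it is created once per module
    return {'retries': 3, 'hosts': ['a', 'b']}


@pytest.fixture
def config(base_config):
    # function-scoped fixture that requests the module-scoped one and
    # hands each test its own deep copy, so mutations cannot leak
    return copy.deepcopy(base_config)


def test_can_mutate_its_copy(config):
    config['hosts'].append('c')
    assert config['hosts'] == ['a', 'b', 'c']


def test_still_sees_original_data(config):
    assert config['hosts'] == ['a', 'b']
```

A deep copy is used because a shallow copy would still share the inner hosts list between tests.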
module
- Definition: A single instance of the fixture is created and shared across all tests in a module.
- Use Case: Suitable for module-level resources that are expensive to set up but can be shared safely among tests.
- Example:
%%writefile pytest_fixture_scope_module.py
import pytest

# The fixture could also live in a separate file, e.g. conftest.py
@pytest.fixture(scope='module')
def resource():
    print("Setting up resource for the module")
    return "module_resource"

def test_a(resource):
    assert resource == "module_resource"

def test_b(resource):
    assert resource == "module_resource"
Overwriting pytest_fixture_scope_module.py
! pytest pytest_fixture_scope_module.py -sv
============================= test session starts ==============================
platform darwin -- Python 3.12.6, pytest-8.2.2, pluggy-1.5.0 -- /Users/a563420/python_training/testing/.venv/bin/python
cachedir: .pytest_cache
rootdir: /Users/a563420/python_training/testing
plugins: cov-6.0.0, anyio-4.4.0
collected 2 items
pytest_fixture_scope_module.py::test_a Setting up resource for the module
PASSED
pytest_fixture_scope_module.py::test_b PASSED
============================== 2 passed in 0.08s ===============================
session
- Definition: A single instance of the fixture is created and shared across the entire test session, spanning multiple modules.
- Use Case: Best for global resources that need to be initialized once and reused across all tests (e.g., database connections, external services).
- Example:
%%writefile conftest.py
import pytest

@pytest.fixture(scope='session')
def resource():
    print("Setting up resource for the session")
    return "session_resource"
Overwriting conftest.py
%%writefile test_fixture_scope_session1.py
import pytest

def test_a(resource):
    assert resource == "session_resource"

def test_b(resource):
    assert resource == "session_resource"
Overwriting test_fixture_scope_session1.py
%%writefile test_fixture_scope_session2.py
import pytest

def test_c(resource):
    assert resource == "session_resource"

def test_d(resource):
    assert resource == "session_resource"
Overwriting test_fixture_scope_session2.py
! pytest ./ -sv
============================= test session starts ==============================
platform darwin -- Python 3.12.6, pytest-8.2.2, pluggy-1.5.0 -- /Users/a563420/python_training/testing/.venv/bin/python
cachedir: .pytest_cache
rootdir: /Users/a563420/python_training/testing
plugins: cov-6.0.0, anyio-4.4.0
collected 5 items
test_capsys.py::test_using_capsys out asdf PASSED
test_fixture_scope_session1.py::test_a Setting up resource for the session
PASSED
test_fixture_scope_session1.py::test_b PASSED
test_fixture_scope_session2.py::test_c PASSED
test_fixture_scope_session2.py::test_d PASSED
============================== 5 passed in 0.09s ===============================
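The finalization examples further down read people.csv from the working directory. A session-scoped fixture could instead create such a file once for the whole run, using pytest's built-in tmp_path_factory fixture (itself session-scoped). Below is a sketch of that approach; the fixture name data_dir and the CSV contents are illustrative. Note that tmp_path would not work here. It is function-scoped, and pytest raises a ScopeMismatch error when a broader-scoped fixture requests a narrower one.

```python
import pytest


@pytest.fixture(scope="session")
def data_dir(tmp_path_factory):
    # mktemp creates a fresh, uniquely named directory; this runs once per session
    path = tmp_path_factory.mktemp("data")
    (path / "people.csv").write_text("first_name,last_name\nJohn,Smith\nAlice,Wilson\n")
    return path


def test_header(data_dir):
    lines = (data_dir / "people.csv").read_text().splitlines()
    assert lines[0] == "first_name,last_name"


def test_row_count(data_dir):
    lines = (data_dir / "people.csv").read_text().splitlines()
    assert len(lines) == 3  # header + two people
```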
Practical Example: Sharing Fixture Instances¶
Exercise: 🏠 DB¶
Tip
A fixture can use other fixtures.
### Too slow!!!
create_database()
test_a()
create_database()
test_b()
create_database()
test_c()
### Optimal
create_database()
reset_database()
test_a()
reset_database()
test_b()
reset_database()
test_c()
Code:¶
# %%writefile pytest_db.py
import pytest
from time import sleep

# Do not modify this code!
def create_database():
    print("create_database")
    sleep(2)
    return []

def reset_database(db):
    print("reset_database")
    del db[:]

# Your fixtures
...  # TODO: define the fixture(s) that provide db

# Your three tests: test_a, test_b and test_c (named test_one, test_two and test_three below)
def test_one(db):
    print('test_one')
    db.append(2)
    assert db == [2]

def test_two(db):
    print('test_two')
    assert db == []

def test_three(db):
    print('test_three')
    db.append(5)
    assert db == [5]
%%writefile pytest_db.py
import pytest
from time import sleep

# Do not modify this code!
def create_database():
    print("create_database")
    sleep(2)
    return []

def reset_database(db):
    print("reset_database")
    del db[:]

# Your fixtures
@pytest.fixture(scope='session')
def shared_db():
    # Expensive: runs only once per test session
    return create_database()

@pytest.fixture
def db(shared_db):
    # Cheap: runs before every test, so each test starts with an empty database
    reset_database(shared_db)
    return shared_db

# Your three tests: test_a, test_b and test_c (named test_one, test_two and test_three below)
def test_one(db):
    print('test_one')
    db.append(2)
    assert db == [2]

def test_two(db):
    print('test_two')
    assert db == []

def test_three(db):
    print('test_three')
    db.append(5)
    assert db == [5]
Writing pytest_db.py
! pytest pytest_db.py -qsvv
============================= test session starts ==============================
platform darwin -- Python 3.12.6, pytest-8.2.2, pluggy-1.5.0 -- /Users/a563420/python_training/testing/.venv/bin/python
cachedir: .pytest_cache
rootdir: /Users/a563420/python_training/testing
plugins: cov-6.0.0, anyio-4.4.0
collected 3 items
pytest_db.py::test_one create_database
reset_database
test_one
PASSED
pytest_db.py::test_two reset_database
test_two
PASSED
pytest_db.py::test_three reset_database
test_three
PASSED
============================== 3 passed in 2.04s ===============================
Fixture Finalization¶
%%writefile pytest_fixture_finalization.py
import pytest

@pytest.fixture
def csv_file_stream():
    stream = open("people.csv")
    print('stream opened')
    yield stream
    # No try/finally needed: pytest runs the code after yield even if the test fails
    stream.close()
    print('stream closed')

def test_people(csv_file_stream):
    print('running test')
    assert False

# def test_people1(csv_file_stream):
#     print('running test')
#     assert False
Overwriting pytest_fixture_finalization.py
! pytest pytest_fixture_finalization.py -q
F [100%]
=================================== FAILURES ===================================
_________________________________ test_people __________________________________
csv_file_stream = <_io.TextIOWrapper name='people.csv' mode='r' encoding='UTF-8'>
    def test_people(csv_file_stream):
        print('running test')
>       assert False
E       assert False
pytest_fixture_finalization.py:15: AssertionError
---------------------------- Captured stdout setup -----------------------------
stream opened
----------------------------- Captured stdout call -----------------------------
running test
--------------------------- Captured stdout teardown ---------------------------
stream closed
=========================== short test summary info ============================
FAILED pytest_fixture_finalization.py::test_people - assert False
1 failed in 0.38s
%%writefile pytest_fixture_finalization_2.py
import pytest

@pytest.fixture
def a(b):
    print('setup a')
    yield 42
    print('teardown a')

@pytest.fixture
def b():
    print('setup b')
    yield 22
    print('teardown b')

@pytest.fixture
def c():
    # Not requested by any test, so it never runs
    print('setup c')
    yield 22
    print('teardown c')

def test_people(a):
    print('running test')
    assert False
Overwriting pytest_fixture_finalization_2.py
! pytest pytest_fixture_finalization_2.py -q
F [100%]
=================================== FAILURES ===================================
_________________________________ test_people __________________________________
a = 42
    def test_people(a):
        print('running test')
>       assert False
E       assert False
pytest_fixture_finalization_2.py:24: AssertionError
---------------------------- Captured stdout setup -----------------------------
setup b
setup a
----------------------------- Captured stdout call -----------------------------
running test
--------------------------- Captured stdout teardown ---------------------------
teardown a
teardown b
=========================== short test summary info ============================
FAILED pytest_fixture_finalization_2.py::test_people - assert False
1 failed in 0.50s
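The captured output shows the ordering rule. Fixtures are set up dependencies-first (b before a) and torn down in reverse order, even when the test fails. Below is a rough pure-Python model of this behaviour. It uses contextlib.ExitStack, which unwinds in last-in, first-out order, as a stand-in for pytest's internal bookkeeping; it is an analogy, not how pytest is actually implemented. Running it prints ['setup b', 'setup a', 'running test', 'failed', 'teardown a', 'teardown b'].

```python
from contextlib import ExitStack, contextmanager

log = []


@contextmanager
def fixture_b():
    log.append("setup b")
    yield 22
    log.append("teardown b")


@contextmanager
def fixture_a(b):
    # a depends on b, so b must already be set up
    log.append("setup a")
    yield 42
    log.append("teardown a")


def run_with_fixtures(test_func):
    # ExitStack exits the registered context managers in LIFO order
    with ExitStack() as stack:
        b = stack.enter_context(fixture_b())
        a = stack.enter_context(fixture_a(b))
        try:
            test_func(a)
            log.append("passed")
        except AssertionError:
            # Record the failure; teardown still happens when the with block exits
            log.append("failed")


def failing_test(a):
    log.append("running test")
    assert False


run_with_fixtures(failing_test)
print(log)
```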
Exercise: 🏠 Advanced Fixture Creation and Verification¶
You are building a test suite for an application that processes data using a shared resource. Your task is to:
- Implement a custom fixture that:
  - Sets up a ResourceManager object before tests run.
  - The ResourceManager should:
    - Keep track of how many times it has been initialized.
    - Provide an increment method to increase an internal counter.
    - Provide a value method to return the current counter value.
  - Implements teardown logic that resets the counter to zero after the tests are complete.
- Write test cases to:
  - Verify that the fixture is created only once when using session scope.
  - Confirm that the internal counter is shared across tests when using session scope.
  - Ensure that the counter value is as expected after each test.
  - Test that the teardown logic properly resets the counter after all tests have run.
Challenge:
- Modify the fixture to use function scope and update the tests accordingly to reflect the change in behavior.
- Ensure that with function scope, the fixture is created anew for each test, and the counter does not retain its value between tests.
Initial Code¶
# %%writefile test_resource_manager.py
import pytest

class ResourceManager:
    initialization_count = 0

    def __init__(self):
        type(self).initialization_count += 1
        self.counter = 0

    def increment(self):
        self.counter += 1

    def value(self):
        return self.counter

    def reset(self):
        self.counter = 0

# TODO: Implement the fixture
# @pytest.fixture(scope='session')
# def resource_manager():
#     pass

# TODO: Write the tests
# def test_increment_1(resource_manager):
#     pass
# def test_increment_2(resource_manager):
#     pass
# def test_initialization_count():
#     pass
# def test_teardown(resource_manager):
#     pass
Solution¶
%%writefile test_resource_manager.py
import pytest

class ResourceManager:
    initialization_count = 0

    def __init__(self):
        type(self).initialization_count += 1
        self.counter = 0

    def increment(self):
        self.counter += 1

    def value(self):
        return self.counter

    def reset(self):
        self.counter = 0

# Implement the fixture with 'session' scope
@pytest.fixture(scope='session')
def resource_manager():
    """
    Fixture that provides a ResourceManager instance.
    """
    rm = ResourceManager()
    yield rm
    # Teardown logic
    rm.reset()

def test_increment_1(resource_manager):
    """
    Test that increments the counter and checks its value.
    """
    resource_manager.increment()
    assert resource_manager.value() == 1, "Counter should be 1 after first increment"

def test_increment_2(resource_manager):
    """
    Another test that increments the counter and checks its value.
    """
    resource_manager.increment()
    assert resource_manager.value() == 2, "Counter should be 2 after second increment"

def test_initialization_count():
    """
    Test that the ResourceManager was initialized only once.
    """
    assert ResourceManager.initialization_count == 1, f"Initialization count should be 1, got {ResourceManager.initialization_count}"

def test_teardown(resource_manager):
    """
    Test that verifies the teardown logic reset the counter.
    """
    # Teardown of a session-scoped fixture runs only at the end of the test session,
    # so inside this (last) test the counter should still be 2
    assert resource_manager.value() == 2, "Counter should be 2 before teardown"
    # After the session ends, the teardown will reset the counter
Writing test_resource_manager.py
! pytest test_resource_manager.py -qsvv
============================= test session starts ==============================
platform darwin -- Python 3.12.6, pytest-8.2.2, pluggy-1.5.0 -- /Users/a563420/python_training/testing/.venv/bin/python
cachedir: .pytest_cache
rootdir: /Users/a563420/python_training/testing
plugins: cov-6.0.0, anyio-4.4.0
collected 4 items
test_resource_manager.py::test_increment_1 PASSED
test_resource_manager.py::test_increment_2 PASSED
test_resource_manager.py::test_initialization_count PASSED
test_resource_manager.py::test_teardown PASSED
============================== 4 passed in 0.10s ===============================
%%writefile test_resource_manager.py
import pytest

class ResourceManager:
    initialization_count = 0

    def __init__(self):
        type(self).initialization_count += 1
        self.counter = 0

    def increment(self):
        self.counter += 1

    def value(self):
        return self.counter

    def reset(self):
        self.counter = 0

# Challenge: Modify the fixture to 'function' scope
@pytest.fixture(scope='function')
def resource_manager_function():
    """
    Fixture that provides a new ResourceManager instance for each test.
    """
    rm = ResourceManager()
    yield rm
    # Teardown logic
    rm.reset()

def test_increment_function_1(resource_manager_function):
    """
    Test that increments the counter and checks its value with function scope.
    """
    resource_manager_function.increment()
    assert resource_manager_function.value() == 1, "Counter should be 1 after increment in function-scoped fixture"

def test_increment_function_2(resource_manager_function):
    """
    Another test that increments the counter and checks its value with function scope.
    """
    resource_manager_function.increment()
    assert resource_manager_function.value() == 1, "Counter should be 1 in a new instance of function-scoped fixture"

def test_initialization_count_function():
    """
    Test that the ResourceManager was initialized multiple times.
    """
    # This file defines only the function-scoped fixture, and only the two tests above
    # have requested it so far, so there have been exactly 2 initializations
    assert ResourceManager.initialization_count == 2, f"Initialization count should be 2, got {ResourceManager.initialization_count}"

def test_teardown_function(resource_manager_function):
    """
    Test that verifies the teardown logic reset the counter in function scope.
    """
    resource_manager_function.increment()
    assert resource_manager_function.value() == 1, "Counter should be 1 before teardown in function-scoped fixture"
    # After the test, the teardown will reset the counter
Overwriting test_resource_manager.py
! pytest test_resource_manager.py -qsvv
============================= test session starts ==============================
platform darwin -- Python 3.12.6, pytest-8.2.2, pluggy-1.5.0 -- /Users/a563420/python_training/testing/.venv/bin/python
cachedir: .pytest_cache
rootdir: /Users/a563420/python_training/testing
plugins: cov-6.0.0, anyio-4.4.0
collected 4 items
test_resource_manager.py::test_increment_1 PASSED
test_resource_manager.py::test_increment_2 PASSED
test_resource_manager.py::test_initialization_count PASSED
test_resource_manager.py::test_teardown PASSED
============================== 4 passed in 0.15s ===============================
%%writefile pytest_fixture_finalization.py
import pytest

@pytest.fixture
def csv_file_stream():
    # stream = open("people.csv")
    # print('stream opened')
    # yield stream
    # # no need for try-finally
    # stream.close()
    # print('stream closed')

    # or simpler:
    with open("people.csv") as stream:
        print("stream opened")
        yield stream
    # implicit stream.close() when the with block exits
    print("stream closed")

def test_people(csv_file_stream):
    print('running test')
    assert False
Overwriting pytest_fixture_finalization.py
! pytest pytest_fixture_finalization.py -q
F [100%]
=================================== FAILURES ===================================
_________________________________ test_people __________________________________
csv_file_stream = <_io.TextIOWrapper name='people.csv' mode='r' encoding='UTF-8'>
    def test_people(csv_file_stream):
        print('running test')
>       assert False
E       assert False
pytest_fixture_finalization.py:22: AssertionError
---------------------------- Captured stdout setup -----------------------------
stream opened
----------------------------- Captured stdout call -----------------------------
running test
--------------------------- Captured stdout teardown ---------------------------
stream closed
=========================== short test summary info ============================
FAILED pytest_fixture_finalization.py::test_people - assert False
1 failed in 0.53s
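Besides yield, pytest also supports registering teardown callbacks with request.addfinalizer. Below is a sketch of the same CSV fixture in that style. It writes the file into the built-in tmp_path directory so it does not depend on people.csv existing; the file contents are illustrative. One practical difference: a finalizer registered before a later setup step raises still runs, whereas code after yield only runs if the fixture reached its yield.

```python
import pytest


@pytest.fixture
def csv_file_stream(request, tmp_path):
    path = tmp_path / "people.csv"
    path.write_text("first_name,last_name\nJohn,Smith\n")
    stream = open(path)
    # The registered callback runs after the test, whether it passes or fails
    request.addfinalizer(stream.close)
    return stream


def test_header(csv_file_stream):
    assert csv_file_stream.readline().strip() == "first_name,last_name"
```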
Parametrizing Fixtures¶
Parametrizing fixtures in Pytest allows you to run the same test function multiple times with different inputs. This is especially useful when you want to test the same functionality with various configurations or data sets without writing separate tests for each case.
- @pytest.fixture(params=[...]): The params argument of the @pytest.fixture decorator specifies a list of parameter values that the fixture will be called with.
- request.param: Within the fixture function, Pytest provides a special request object whose param attribute holds the current parameter value for each invocation of the fixture.
For each parameter value specified in params, Pytest:
- Calls the fixture function with request.param set to that value.
- Executes all tests that use this fixture once for each parameter value.
%%writefile people1.csv
first_name,last_name
John,Smith
Alice,Wilson
Overwriting people1.csv
%%writefile people2.csv
first_name,last_name
Overwriting people2.csv
%%writefile pytest_parametrized_fixtures.py
import pytest

@pytest.fixture(params=['people1.csv', 'people2.csv'])
def people_csv_stream(request):
    print(f'opening {request.param}')
    with open(request.param) as stream:
        yield stream

def test_people(people_csv_stream):
    print('test_people')
    print(people_csv_stream.read())
Overwriting pytest_parametrized_fixtures.py
! pytest pytest_parametrized_fixtures.py -qs
opening people1.csv
test_people
first_name,last_name
John,Smith
Alice,Wilson
.opening people2.csv
test_people
first_name,last_name
.
2 passed in 0.07s
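By default, parametrized fixtures get test IDs derived from the parameter values. The ids argument of @pytest.fixture lets you choose readable names instead. Below is a sketch that combines ids with in-memory CSV data, so no files are needed. The CSV_SOURCES dict and fixture name are illustrative. With these ids, pytest would report test_rows_have_both_columns[two_people] and test_rows_have_both_columns[header_only].

```python
import csv
import io

import pytest

CSV_SOURCES = {
    "two_people": "first_name,last_name\nJohn,Smith\nAlice,Wilson\n",
    "header_only": "first_name,last_name\n",
}


@pytest.fixture(params=list(CSV_SOURCES.values()), ids=list(CSV_SOURCES))
def people_rows(request):
    # request.param is the CSV text for the current parametrization
    return list(csv.DictReader(io.StringIO(request.param)))


def test_rows_have_both_columns(people_rows):
    for row in people_rows:
        assert set(row) == {"first_name", "last_name"}
```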
Factory Fixtures¶
In the previous section, we explored how to use parametrized fixtures to run the same test function multiple times with different inputs. While parametrized fixtures are powerful, they have limitations when you need to generate test data dynamically or when the inputs cannot be predetermined. This is where factory fixtures come into play.
Factory fixtures allow you to pass arguments to a fixture at test time, giving you greater flexibility in setting up your test data. They are particularly useful when you need to create multiple test data instances with varying attributes within a single test function.
- Factory Fixture: A fixture that returns a function (the factory) which can accept arguments to create test data dynamically.
- Why Use Factory Fixtures?
- Dynamic Data Creation: When test data cannot be predefined and needs to be generated on the fly.
- Flexibility: Allows tests to specify exactly what data they need.
- Reusability: Centralizes the data creation logic, making it reusable across different tests.
%%writefile pytest_factory_fixture.py
import pytest

class Customer:
    def __init__(self, first_name, last_name, email, **kwargs):
        self.first_name = first_name
        self.last_name = last_name
        self.email = email
        for key, value in kwargs.items():
            setattr(self, key, value)

# Factory fixture that creates Customer instances
@pytest.fixture
def make_customer():
    def _make_customer(
        first_name="John",
        last_name="Doe",
        email="john.doe@example.com",
        **extra_attrs
    ):
        customer = Customer(
            first_name=first_name,
            last_name=last_name,
            email=email,
            **extra_attrs
        )
        return customer
    return _make_customer

def test_create_default_customer(make_customer):
    customer = make_customer()
    assert customer.first_name == "John"
    assert customer.last_name == "Doe"
    assert customer.email == "john.doe@example.com"

def test_create_custom_customer(make_customer):
    customer = make_customer(first_name="Alice", last_name="Smith", email="alice.smith@example.com")
    assert customer.first_name == "Alice"
    assert customer.last_name == "Smith"
    assert customer.email == "alice.smith@example.com"

def test_create_customer_with_extra_attrs(make_customer):
    customer = make_customer(age=30, country="USA")
    assert customer.age == 30
    assert customer.country == "USA"
Overwriting pytest_factory_fixture.py
! pytest pytest_factory_fixture.py -qs
... 3 passed in 0.07s
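A factory fixture can also clean up everything it produced. It keeps track of the created objects and finalizes them after yield. Below is a minimal sketch of that pattern; the Connection class and its close method are hypothetical stand-ins for a real resource.

```python
import pytest


class Connection:
    """Hypothetical resource that must be closed after use."""

    def __init__(self, name):
        self.name = name
        self.closed = False

    def close(self):
        self.closed = True


@pytest.fixture
def make_connection():
    created = []

    def _make_connection(name="default"):
        conn = Connection(name)
        created.append(conn)  # remember it so teardown can close it
        return conn

    yield _make_connection
    # Teardown: close every connection this test created, however many there were
    for conn in created:
        conn.close()


def test_two_connections(make_connection):
    primary = make_connection("primary")
    replica = make_connection("replica")
    assert primary is not replica
    assert not primary.closed and not replica.closed
```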
Composing Fixtures¶
In Pytest, fixtures are a powerful way to manage test setup and teardown. One of the key strengths of fixtures is their ability to depend on other fixtures. This allows you to compose complex test scenarios by building upon simpler, reusable components. By composing fixtures, you can avoid redundant code and create more maintainable and scalable test suites.
Example: Composing Fixtures in an E-commerce Application¶
Suppose you're testing an e-commerce application with the classes Customer, Sale, and Transaction, which depend on one another:
- A Transaction involves a Sale and a Customer.
- A Sale is made by a Customer.
Instead of creating a large fixture that sets up everything at once, you can create individual fixtures for each component and compose them.
%%writefile pytest_composing_fixture.py
import pytest

class Customer:
    def __init__(self, first_name, last_name, email, **kwargs):
        self.first_name = first_name
        self.last_name = last_name
        self.email = email
        # Store any extra attributes passed through the factory
        for key, value in kwargs.items():
            setattr(self, key, value)

class Sale:
    def __init__(self, amount, sku, customer):
        self.amount = amount
        self.sku = sku
        self.customer = customer

class Transaction:
    def __init__(self, transaction_id, sale, customer):
        self.transaction_id = transaction_id
        self.sale = sale
        self.customer = customer

@pytest.fixture
def make_customer():
    def _make_customer(
        first_name="John",
        last_name="Doe",
        email="john.doe@example.com",
        **extra_attrs
    ):
        customer = Customer(
            first_name=first_name,
            last_name=last_name,
            email=email,
            **extra_attrs
        )
        return customer
    return _make_customer

# make_sale builds on make_customer
@pytest.fixture
def make_sale(make_customer):
    def _make_sale(amount=100.0, sku="ABC123", customer=None):
        if customer is None:
            customer = make_customer()
        sale = Sale(amount=amount, sku=sku, customer=customer)
        return sale
    return _make_sale

# make_transaction builds on make_sale (and, through it, on make_customer)
@pytest.fixture
def make_transaction(make_sale, make_customer):
    def _make_transaction(transaction_id, sale=None, customer=None):
        if sale is None:
            sale = make_sale(customer=customer)
        if customer is None:
            customer = sale.customer
        transaction = Transaction(
            transaction_id=transaction_id,
            sale=sale,
            customer=customer,
        )
        return transaction
    return _make_transaction

def test_transaction_creation(make_transaction):
    transaction = make_transaction(transaction_id="TXN1001")
    assert transaction.transaction_id == "TXN1001"
    assert transaction.sale.amount == 100.0
    assert transaction.customer.first_name == "John"

def test_transaction_with_custom_customer(make_transaction, make_customer):
    custom_customer = make_customer(first_name="Alice", last_name="Smith")
    transaction = make_transaction(transaction_id="TXN1002", customer=custom_customer)
    assert transaction.customer.first_name == "Alice"
    assert transaction.customer.last_name == "Smith"
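Several fixtures in the same test may request a common fixture. Pytest then creates that fixture only once per test (with the default function scope) and hands the same object to every requester. Below is a small sketch that makes this visible; the registry fixture is illustrative and not part of the e-commerce example.

```python
import pytest


@pytest.fixture
def registry():
    # Function scope: one fresh list per test, shared by everything in that test
    return []


@pytest.fixture
def make_customer(registry):
    def _make_customer(name):
        registry.append(name)
        return name
    return _make_customer


def test_factory_and_test_share_registry(make_customer, registry):
    make_customer("Alice")
    make_customer("Bob")
    # The registry seen by the test is the same list the factory appends to
    assert registry == ["Alice", "Bob"]
```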
Exercise: 🏠 Advanced Factory and Composed Fixtures¶
Description¶
You are tasked with testing a complex Library Management System. The system consists of several interconnected classes:
- Book: Represents a book with attributes such as title, author, isbn, and availability status.
- Member: Represents a library member with attributes like name, member_id, and a list of borrowed_books.
- Library: Manages collections of books and members, and provides methods for borrowing and returning books.
Your objectives are:
- Create factory fixtures for Book and Member that allow dynamic creation of instances with customizable attributes.
- Compose fixtures to create a Library instance that depends on the Book and Member fixtures.
- Write tests that:
  - Verify that a member can borrow a book successfully.
  - Ensure that a book's availability status updates correctly when borrowed and returned.
  - Confirm that a member cannot borrow more books than the allowed limit.
- Implement fixture scopes appropriately to optimize test performance and ensure proper isolation between tests.
- Challenge: Modify the fixtures to handle parameterization, allowing tests to run with different configurations (e.g., varying the maximum number of books a member can borrow).
Initial Code¶
# %%writefile test_library_system.py
import pytest

class Book:
    def __init__(self, title, author, isbn):
        self.title = title
        self.author = author
        self.isbn = isbn
        self.is_available = True

class Member:
    def __init__(self, name, member_id, max_books=3):
        self.name = name
        self.member_id = member_id
        self.borrowed_books = []
        self.max_books = max_books

class Library:
    def __init__(self):
        self.books = []
        self.members = []

    def add_book(self, book):
        # Add book to the library collection
        pass  # TODO: Implement this method

    def register_member(self, member):
        # Register a new library member
        pass  # TODO: Implement this method

    def borrow_book(self, member_id, isbn):
        # Member borrows a book
        pass  # TODO: Implement this method

    def return_book(self, member_id, isbn):
        # Member returns a book
        pass  # TODO: Implement this method

# TODO: Implement factory fixtures for Book and Member
# @pytest.fixture
# def make_book():
#     pass
# @pytest.fixture
# def make_member():
#     pass

# TODO: Implement a composed fixture for Library
# @pytest.fixture
# def library(make_book, make_member):
#     pass

# TODO: Write tests using the fixtures
# def test_member_can_borrow_book(library, make_book, make_member):
#     pass
# def test_book_availability_updates(library, make_book, make_member):
#     pass
# def test_member_borrow_limit(library, make_book, make_member):
#     pass
Solution¶
%%writefile test_library_system.py
import uuid

import pytest

class Book:
    def __init__(self, title, author, isbn):
        self.title = title
        self.author = author
        self.isbn = isbn
        self.is_available = True

class Member:
    def __init__(self, name, member_id, max_books=3):
        self.name = name
        self.member_id = member_id
        self.borrowed_books = []
        self.max_books = max_books

class Library:
    def __init__(self):
        self.books = {}
        self.members = {}

    def add_book(self, book):
        self.books[book.isbn] = book

    def register_member(self, member):
        self.members[member.member_id] = member

    def borrow_book(self, member_id, isbn):
        member = self.members.get(member_id)
        book = self.books.get(isbn)
        if not member or not book:
            return False  # or raise exception
        if not book.is_available:
            return False  # or raise exception
        if len(member.borrowed_books) >= member.max_books:
            return False  # or raise exception
        book.is_available = False
        member.borrowed_books.append(book)
        return True

    def return_book(self, member_id, isbn):
        member = self.members.get(member_id)
        book = self.books.get(isbn)
        if not member or not book:
            return False  # or raise exception
        if book not in member.borrowed_books:
            return False  # or raise exception
        book.is_available = True
        member.borrowed_books.remove(book)
        return True

# Factory fixture for Book (stateless, so session scope is safe)
@pytest.fixture(scope='session')
def make_book():
    def _make_book(title="Default Title", author="Default Author", isbn=None):
        if isbn is None:
            isbn = str(uuid.uuid4())  # unique ISBN for every book
        return Book(title=title, author=author, isbn=isbn)
    return _make_book

# Factory fixture for Member (stateless, so session scope is safe)
@pytest.fixture(scope='session')
def make_member():
    def _make_member(name="Default Name", member_id=None, max_books=3):
        if member_id is None:
            member_id = str(uuid.uuid4())  # unique ID for every member
        return Member(name=name, member_id=member_id, max_books=max_books)
    return _make_member

# Composed fixture for Library (function scope: a fresh library for every test)
@pytest.fixture
def library(make_book, make_member):
    lib = Library()
    # Add some default books and members
    book1 = make_book(title="Book One", author="Author A")
    book2 = make_book(title="Book Two", author="Author B")
    lib.add_book(book1)
    lib.add_book(book2)
    member1 = make_member(name="Member One")
    member2 = make_member(name="Member Two")
    lib.register_member(member1)
    lib.register_member(member2)
    return lib

def test_member_can_borrow_book(library, make_book, make_member):
    member = make_member(name="Test Member")
    library.register_member(member)
    book = make_book(title="Test Book")
    library.add_book(book)
    assert library.borrow_book(member.member_id, book.isbn) is True  # the return-value check is optional for what this test verifies
    assert book in member.borrowed_books
    assert not book.is_available

def test_book_availability_updates(library, make_book, make_member):
    member = make_member(name="Test Member")
    library.register_member(member)
    book = make_book(title="Test Book")
    library.add_book(book)
    # Borrow the book
    library.borrow_book(member.member_id, book.isbn)
    assert not book.is_available
    # Return the book
    library.return_book(member.member_id, book.isbn)
    assert book.is_available
    assert book not in member.borrowed_books

def test_member_borrow_limit(library, make_book, make_member):
    member = make_member(name="Test Member", max_books=2)
    library.register_member(member)
    books = [make_book(title=f"Book {i}") for i in range(3)]
    for book in books:
        library.add_book(book)
    # Borrow first book
    assert library.borrow_book(member.member_id, books[0].isbn) is True
    # Borrow second book
    assert library.borrow_book(member.member_id, books[1].isbn) is True
    # Attempt to borrow third book should fail
    assert library.borrow_book(member.member_id, books[2].isbn) is False
    assert len(member.borrowed_books) == 2

# Challenge: Parameterize the max_books limit
@pytest.fixture(params=[1, 2, 3])
def member_with_limit(make_member, request):
    return make_member(max_books=request.param)

def test_member_borrow_limit_parametrized(library, make_book, member_with_limit):
    """
    library -> object
    make_book -> factory function
    member_with_limit -> object
    """
    library.register_member(member_with_limit)
    books = [make_book(title=f"Book {i}") for i in range(5)]
    for book in books:
        library.add_book(book)
    borrow_count = 0
    for book in books:
        success = library.borrow_book(member_with_limit.member_id, book.isbn)
        if success:
            borrow_count += 1
        else:
            break
    assert borrow_count == member_with_limit.max_books
    assert len(member_with_limit.borrowed_books) == member_with_limit.max_books
Overwriting test_library_system.py
! pytest test_library_system.py -sv
============================= test session starts ==============================
platform darwin -- Python 3.12.6, pytest-8.2.2, pluggy-1.5.0 -- /Users/a563420/python_training/testing/.venv/bin/python
cachedir: .pytest_cache
rootdir: /Users/a563420/python_training/testing
plugins: cov-6.0.0, anyio-4.4.0
collected 6 items
test_library_system.py::test_member_can_borrow_book PASSED
test_library_system.py::test_book_availability_updates PASSED
test_library_system.py::test_member_borrow_limit PASSED
test_library_system.py::test_member_borrow_limit_parametrized[1] PASSED
test_library_system.py::test_member_borrow_limit_parametrized[2] PASSED
test_library_system.py::test_member_borrow_limit_parametrized[3] PASSED
============================== 6 passed in 0.08s ===============================
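The challenge solution parametrizes the fixture itself, so every test using member_with_limit runs three times. A test can instead pick the fixture's parameters individually with parametrize(..., indirect=True), which routes each value to the fixture as request.param. Below is a minimal sketch of that variant. The simplified Member class and the default limit of 3 are assumptions for illustration.

```python
import pytest


class Member:
    def __init__(self, name, max_books=3):
        self.name = name
        self.max_books = max_books


@pytest.fixture
def member(request):
    # With indirect=True, request.param holds the value from parametrize;
    # without parametrization the attribute is absent, so fall back to 3
    max_books = getattr(request, "param", 3)
    return Member(name="Test Member", max_books=max_books)


@pytest.mark.parametrize("member", [1, 5], indirect=True)
def test_custom_limit(member):
    assert member.max_books in (1, 5)


def test_default_limit(member):
    assert member.max_books == 3
```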
More on pytest¶
Grouping Tests¶
When writing tests, it’s important to organize them in a way that makes them easy to understand, maintain, and extend. Pytest supports grouping tests by encapsulating related tests into classes. This approach helps structure your test suite, especially for large projects, by logically grouping tests that share a common purpose or context.
%%writefile pytest_grouped_tests.py
def div(a, b):
    return a / b

def mul(a, b):
    return a * b

class TestDiv:
    def test_one(self):
        assert div(5, 2) == 2.5

    def test_two(self):
        assert div(4, 2) == 2

class TestMul:
    def test_one(self):
        assert mul(2, 2) == 4
Overwriting pytest_grouped_tests.py
! pytest pytest_grouped_tests.py -q
... [100%] 3 passed in 0.03s
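A test class can also hold fixtures of its own. A fixture defined as a method of the class is visible only to that class's tests. Pytest also creates a new instance of the class for each test method, so instance attributes are not a way to share state between tests. Below is a sketch reusing div from above; the divisor fixture is illustrative.

```python
import pytest


def div(a, b):
    return a / b


class TestDivWithFixture:
    @pytest.fixture
    def divisor(self):
        # Only tests inside TestDivWithFixture can request this fixture
        return 2

    def test_even(self, divisor):
        assert div(4, divisor) == 2

    def test_odd(self, divisor):
        assert div(5, divisor) == 2.5
```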
Skipping Tests¶
In some scenarios, you may want to conditionally skip certain tests or explicitly mark tests to be skipped. Pytest provides powerful decorators to skip tests based on conditions such as the Python version, operating system, or other custom logic. This feature is particularly useful when writing platform-specific tests or handling features that depend on external environments.
%%writefile pytest_skipping_tests.py
import sys
import pytest

@pytest.mark.skipif(sys.version_info < (3, 6), reason="requires python3.6 or higher")  # condition True -> skip, False -> run
def test_1():
    print("Test launched only on Python 3.6 or higher")

@pytest.mark.skipif(sys.version_info >= (3, 7), reason="requires python3.6 or older")
def test_2():
    print("Test launched only on Python 3.6 and older")

@pytest.mark.skipif(sys.platform != 'win32', reason="requires windows")
def test_3():
    print("Testing win32 specific features")

# A skip marker can be stored in a variable and reused
only_win32 = pytest.mark.skipif(sys.platform != 'win32', reason="requires windows")

@only_win32
def test_4():
    print("Testing win32 specific features")

@only_win32
def test_5():
    print("Testing win32 specific features")

@pytest.mark.skip()
def test_6():
    print("Test never launched")

@pytest.mark.skipif(sys.platform != 'darwin', reason="requires macOS")
def test_7():
    print("Testing macOS specific features")
Overwriting pytest_skipping_tests.py
! pytest pytest_skipping_tests.py -v
============================= test session starts ==============================
platform darwin -- Python 3.12.6, pytest-8.2.2, pluggy-1.5.0 -- /Users/a563420/python_training/testing/.venv/bin/python
cachedir: .pytest_cache
rootdir: /Users/a563420/python_training/testing
plugins: cov-6.0.0, anyio-4.4.0
collected 7 items
pytest_skipping_tests.py::test_1 PASSED [ 14%]
pytest_skipping_tests.py::test_2 SKIPPED (requires python3.6 or older) [ 28%]
pytest_skipping_tests.py::test_3 SKIPPED (requires windows) [ 42%]
pytest_skipping_tests.py::test_4 SKIPPED (requires windows) [ 57%]
pytest_skipping_tests.py::test_5 SKIPPED (requires windows) [ 71%]
pytest_skipping_tests.py::test_6 SKIPPED (unconditional skip) [ 85%]
pytest_skipping_tests.py::test_7 PASSED [100%]
========================= 2 passed, 5 skipped in 0.10s =========================
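The version checks compare sys.version_info, a tuple-like value, element by element. One subtlety: a longer tuple that starts with the same elements compares greater. So sys.version_info > (3, 6) is already true on Python 3.6.0 itself, and a test guarded by it would be skipped there too. The sketch below checks this with plain tuples shaped like sys.version_info; the helper names are illustrative.

```python
import sys

# A tuple shaped like sys.version_info on Python 3.6.0
PY360 = (3, 6, 0, "final", 0)


def skipped_by_gt(version):
    # Intended as "skip unless Python 3.6 or older", but also true on 3.6.x
    return version > (3, 6)


def skipped_by_ge_next(version):
    # Skips only from Python 3.7 onwards, so 3.6.x still runs the test
    return version >= (3, 7)


print(skipped_by_gt(PY360), skipped_by_ge_next(PY360))  # True False
print(skipped_by_ge_next(tuple(sys.version_info)))
```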
Parametrized Tests¶
Parametrized tests in Pytest allow you to run the same test function multiple times with different sets of input data. This is a powerful feature for efficiently testing a wide range of input scenarios without duplicating code.
%%writefile pytest_parametrized_tests.py
import pytest

# Stacked parametrize decorators run every combination: 3 numbers x 3 letters = 9 tests
@pytest.mark.parametrize('number', [10, 20, 30])
@pytest.mark.parametrize('letter', ['a', 'b', 'c'])
def test_something(number, letter):
    print(number, letter)
Overwriting pytest_parametrized_tests.py
! pytest pytest_parametrized_tests.py -v
============================= test session starts ==============================
platform darwin -- Python 3.12.6, pytest-8.2.2, pluggy-1.5.0 -- /Users/a563420/python_training/testing/.venv/bin/python
cachedir: .pytest_cache
rootdir: /Users/a563420/python_training/testing
plugins: cov-6.0.0, anyio-4.4.0
collected 9 items

pytest_parametrized_tests.py::test_something[a-10] PASSED [ 11%]
pytest_parametrized_tests.py::test_something[a-20] PASSED [ 22%]
pytest_parametrized_tests.py::test_something[a-30] PASSED [ 33%]
pytest_parametrized_tests.py::test_something[b-10] PASSED [ 44%]
pytest_parametrized_tests.py::test_something[b-20] PASSED [ 55%]
pytest_parametrized_tests.py::test_something[b-30] PASSED [ 66%]
pytest_parametrized_tests.py::test_something[c-10] PASSED [ 77%]
pytest_parametrized_tests.py::test_something[c-20] PASSED [ 88%]
pytest_parametrized_tests.py::test_something[c-30] PASSED [100%]

============================== 9 passed in 0.12s ===============================
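Individual cases can also be wrapped in `pytest.param` to give them a readable id, or to attach a marker to just that one case. A sketch with a hypothetical `is_palindrome` function standing in for the system under test:

```python
import sys

import pytest


def is_palindrome(text: str) -> bool:
    # Hypothetical SUT: ignores case and spaces
    cleaned = text.replace(" ", "").lower()
    return cleaned == cleaned[::-1]


@pytest.mark.parametrize(
    "text, expected",
    [
        pytest.param("racecar", True, id="simple"),
        pytest.param("Never odd or even", True, id="mixed-case-with-spaces"),
        pytest.param("python", False, id="not-a-palindrome"),
        pytest.param("", True, id="empty-string"),
        # marks= applies a marker to this single case only
        pytest.param(
            "Kayak", True, id="windows-only",
            marks=pytest.mark.skipif(sys.platform != "win32", reason="requires windows"),
        ),
    ],
)
def test_is_palindrome(text, expected):
    assert is_palindrome(text) == expected
```

With `-v`, pytest reports these cases as `test_is_palindrome[simple]`, `test_is_palindrome[mixed-case-with-spaces]` and so on, instead of ids built from the raw values.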
%%writefile pytest_parametrized_tests_1.py
import pytest

# Function to calculate area of a rectangle
def calculate_area(length, width):
    return length * width

@pytest.mark.parametrize('length,width,expected', [
    (5, 10, 50),  # Case 1: Normal rectangle
    (0, 10, 0),   # Case 2: Zero length
    (5, 0, 0),    # Case 3: Zero width
    (5, 5, 25),   # Case 4: Square
])
def test_calculate_area(length, width, expected):
    assert calculate_area(length, width) == expected
Overwriting pytest_parametrized_tests_1.py
! pytest pytest_parametrized_tests_1.py -v
============================= test session starts ==============================
platform darwin -- Python 3.12.6, pytest-8.2.2, pluggy-1.5.0 -- /Users/a563420/python_training/testing/.venv/bin/python
cachedir: .pytest_cache
rootdir: /Users/a563420/python_training/testing
plugins: cov-6.0.0, anyio-4.4.0
collected 4 items

pytest_parametrized_tests_1.py::test_calculate_area[5-10-50] PASSED [ 25%]
pytest_parametrized_tests_1.py::test_calculate_area[0-10-0] PASSED [ 50%]
pytest_parametrized_tests_1.py::test_calculate_area[5-0-0] PASSED [ 75%]
pytest_parametrized_tests_1.py::test_calculate_area[5-5-25] PASSED [100%]

============================== 4 passed in 0.16s ===============================
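The rectangle cases use integers, so `==` is exact. With floats, a product such as `0.1 * 3` evaluates to `0.30000000000000004`, and an exact comparison would fail. A sketch of the same parametrized test for float inputs using `pytest.approx`, which by default compares within a small relative tolerance:

```python
import pytest


def calculate_area(length, width):
    return length * width


@pytest.mark.parametrize('length, width, expected', [
    (0.1, 3, 0.3),    # 0.1 * 3 == 0.30000000000000004
    (2.5, 4, 10.0),
    (1 / 3, 3, 1.0),
])
def test_calculate_area_floats(length, width, expected):
    # Compare with a tolerance instead of exact equality
    assert calculate_area(length, width) == pytest.approx(expected)
```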
Exercise: 🏠 Dividing 2¶
Rewrite Tests for div as Parametrized Tests
Solution¶
%%writefile pytest_dividing_2.py
# Best solution
import math

import pytest

def div(a, b):
    return a / b

class Dividable:
    def __truediv__(self, other):
        return 42

class MyClass:
    pass

inf = float('inf')
nan = float('nan')

# Test cases for returned values
@pytest.mark.parametrize('a, b, expected', [
    (4, 2, 2),
    (3, 2, 1.5),
    (-3, -2, 1.5),
    (True, 2, 0.5),
    (5e-324, 2, 0.0),
    (1e308, 0.5, inf),
    (Dividable(), Dividable(), 42),
    (inf, 2, inf),
    (inf, inf, nan),
])
def test_div_returned_values(a, b, expected):
    got = div(a, b)
    if math.isnan(expected):
        assert math.isnan(got)  # nan != nan, so compare with isnan
    else:
        assert got == expected

# Test cases for exceptions
@pytest.mark.parametrize('a, b, exception', [
    (2, 0, ZeroDivisionError),
    (inf, 0, ZeroDivisionError),
    ([3, 2, 5], [1, 2, 3], TypeError),
    (MyClass(), MyClass(), TypeError),
])
def test_div_exceptions(a, b, exception):
    with pytest.raises(exception):
        div(a, b)

# Test cases for type of returned values
@pytest.mark.parametrize('a, b, expected_type', [
    (4, 2, float),
    (3, 2, float),
    (-3, -2, float),
    (True, 2, float),
    (5e-324, 2, float),
    (1e308, 0.5, float),
    (Dividable(), Dividable(), int),
    (inf, 2, float),
    (inf, inf, float),
])
def test_div_returned_type(a, b, expected_type):
    got = div(a, b)
    assert isinstance(got, expected_type)
Overwriting pytest_dividing_2.py
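The solution above keeps returned values and exceptions in separate parametrized tests. Another pattern, which the pytest documentation also shows for conditional raising, folds both into one table: each case carries a context manager, `pytest.raises(...)` for inputs that should fail and `contextlib.nullcontext()` for inputs that should not. A minimal sketch reusing `div`; the alias `does_not_raise` is our own naming:

```python
from contextlib import nullcontext as does_not_raise

import pytest


def div(a, b):
    return a / b


@pytest.mark.parametrize(
    "a, b, expected, expectation",
    [
        (4, 2, 2, does_not_raise()),
        (3, 2, 1.5, does_not_raise()),
        (-3, -2, 1.5, does_not_raise()),
        (2, 0, None, pytest.raises(ZeroDivisionError)),
        ([3, 2, 5], [1, 2, 3], None, pytest.raises(TypeError)),
    ],
)
def test_div(a, b, expected, expectation):
    # For raising cases the assertion is never reached; pytest.raises checks the exception
    with expectation:
        assert div(a, b) == expected
```

The trade-off: one table is compact, but a reader has to check the last column to know which cases are expected to fail.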
%%writefile pytest_dividing_2.py
# shows how to use objects as parameters
# shows how to generate test descriptions dynamically
from pydantic import BaseModel
import math
from typing import Union, Optional, Type
import pytest

def div(a, b):
    return a / b

# Ways to describe a case, in order of preference: pydantic v2 -> dataclass -> pydantic v1
class Case(BaseModel):
    a: object
    b: object
    expected: Optional[Union[float, int]] = None
    exception: Optional[Type[Exception]] = None
    doc: Optional[str] = None

class Dividable:
    def __truediv__(self, other):
        return 42
    def __str__(self):
        return 'Dividable'

class MyClass:
    pass

inf = float('inf')
nan = float('nan')

cases = [
    Case(a=4, b=2, expected=2),
    Case(a=3, b=2, expected=1.5, doc="Dividing two integers should give a float"),
    Case(a=-3, b=-2, expected=1.5),
    Case(a=True, b=2, expected=0.5),
    Case(a=5e-324, b=2, expected=0.0),
    Case(a=1e308, b=0.5, expected=inf),
    Case(a=Dividable(), b=Dividable(), expected=42),
    Case(a=inf, b=2, expected=inf),
    Case(a=inf, b=inf, expected=nan),
    Case(a=2, b=0, exception=ZeroDivisionError),
    Case(a=inf, b=0, exception=ZeroDivisionError),
    Case(a=[3, 2, 5], b=[1, 2, 3], exception=TypeError),
    Case(a=MyClass(), b=MyClass(), exception=TypeError),
]

# @pytest.mark.parametrize('case', cases)
# ids= turns each case into a readable test id: its doc, or "div(a, b)"
@pytest.mark.parametrize('case', cases, ids=lambda case: case.doc or f'div({case.a}, {case.b})')
def test_div_cases(case: Case):
    if case.exception is not None:
        with pytest.raises(case.exception):
            div(case.a, case.b)
    else:
        got = div(case.a, case.b)
        if math.isnan(case.expected):
            assert math.isnan(got)
        else:
            assert got == case.expected

def test_dividing_two_integers_returns_a_float_even_when_result_is_integer():
    assert isinstance(div(4, 2), float)
Overwriting pytest_dividing_2.py
! pytest pytest_dividing_2.py -v
============================= test session starts ==============================
platform darwin -- Python 3.12.6, pytest-8.2.2, pluggy-1.5.0 -- /Users/a563420/python_training/testing/.venv/bin/python
cachedir: .pytest_cache
rootdir: /Users/a563420/python_training/testing
plugins: cov-6.0.0, anyio-4.4.0
collected 14 items

pytest_dividing_2.py::test_div_cases[div(4, 2)] PASSED [ 7%]
pytest_dividing_2.py::test_div_cases[Dividing two integers should give a float] PASSED [ 14%]
pytest_dividing_2.py::test_div_cases[div(-3, -2)] PASSED [ 21%]
pytest_dividing_2.py::test_div_cases[div(True, 2)] PASSED [ 28%]
pytest_dividing_2.py::test_div_cases[div(5e-324, 2)] PASSED [ 35%]
pytest_dividing_2.py::test_div_cases[div(1e+308, 0.5)] PASSED [ 42%]
pytest_dividing_2.py::test_div_cases[div(Dividable, Dividable)] PASSED [ 50%]
pytest_dividing_2.py::test_div_cases[div(inf, 2)] PASSED [ 57%]
pytest_dividing_2.py::test_div_cases[div(inf, inf)] PASSED [ 64%]
pytest_dividing_2.py::test_div_cases[div(2, 0)] PASSED [ 71%]
pytest_dividing_2.py::test_div_cases[div(inf, 0)] PASSED [ 78%]
pytest_dividing_2.py::test_div_cases[div([3, 2, 5], [1, 2, 3])] PASSED [ 85%]
pytest_dividing_2.py::test_div_cases[div(<pytest_dividing_2.MyClass object at 0x10e5479e0>, <pytest_dividing_2.MyClass object at 0x10e547650>)] PASSED [ 92%]
pytest_dividing_2.py::test_dividing_two_integers_returns_a_float_even_when_result_is_integer PASSED [100%]

============================== 14 passed in 0.60s ==============================
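The comment in the cell above names `dataclass` as an alternative to pydantic for describing cases. A sketch of that variant using only the standard library; defining `__str__` on `Case` lets `ids=str` produce the test ids. Note that in the run above the id of the `MyClass` case contains memory addresses, which change between runs; giving such helper classes their own `__repr__` keeps ids stable.

```python
import math
from dataclasses import dataclass
from typing import Optional, Type

import pytest


def div(a, b):
    return a / b


@dataclass(frozen=True)
class Case:
    a: object
    b: object
    expected: Optional[float] = None
    exception: Optional[Type[Exception]] = None
    doc: Optional[str] = None

    def __str__(self):
        # Used as the test id via ids=str below
        return self.doc or f"div({self.a!r}, {self.b!r})"


cases = [
    Case(4, 2, expected=2),
    Case(3, 2, expected=1.5, doc="Dividing two integers should give a float"),
    Case(float("inf"), float("inf"), expected=float("nan")),
    Case(2, 0, exception=ZeroDivisionError),
]


@pytest.mark.parametrize("case", cases, ids=str)
def test_div_cases(case):
    if case.exception is not None:
        with pytest.raises(case.exception):
            div(case.a, case.b)
        return
    got = div(case.a, case.b)
    if math.isnan(case.expected):
        assert math.isnan(got)
    else:
        assert got == case.expected
```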
Launch Tests Programmatically¶
import os
import pytest
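os.chdir('.')  # directory to run the tests from ('.' keeps the current one)
r = pytest.main(args=['-s'])  # same arguments as on the command line; -s shows print() output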
print(type(r))
print(r)
============================= test session starts ==============================
platform darwin -- Python 3.12.6, pytest-8.2.2, pluggy-1.5.0
rootdir: /Users/a563420/python_training/testing
plugins: cov-6.0.0, anyio-4.4.0
collected 20 items

katta/test_bowling_game.py .....
test_capsys.py out asdf .
test_fixture_scope_session1.py Setting up resource for the session ..
test_fixture_scope_session2.py ..
test_library_system.py ......
test_resource_manager.py ..F.

=================================== FAILURES ===================================
__________________________ test_initialization_count ___________________________

    def test_initialization_count():
        """
        Test that the ResourceManager was initialized only once.
        """
>       assert ResourceManager.initialization_count == 1, f"Initialization count should be 1, got {ResourceManager.initialization_count}"
E       AssertionError: Initialization count should be 1, got 4
E       assert 4 == 1
E        +  where 4 = ResourceManager.initialization_count

test_resource_manager.py:48: AssertionError
=========================== short test summary info ============================
FAILED test_resource_manager.py::test_initialization_count - AssertionError: Initialization count should be 1, got 4
========================= 1 failed, 19 passed in 0.26s =========================
<enum 'ExitCode'>
1
Code Coverage¶
%%writefile abs_util.py
def absolute(x):
    if x >= 0:
        return x
    else:
        return -x
Writing abs_util.py
%%writefile pytest_coverage.py
from abs_util import absolute

def test_absolute():
    assert absolute(-2) == 2
Writing pytest_coverage.py
Coverage Package¶
! pip install coverage
Requirement already satisfied: coverage in ./.venv/lib/python3.12/site-packages (7.6.7)
! coverage run --source=. -m pytest pytest_coverage.py -q
. [100%]
1 passed in 0.12s
! coverage report
Name                               Stmts   Miss  Cover
------------------------------------------------------
abs_util.py                            4      1    75%
calculator.py                         19     19     0%
conftest.py                            5      2    60%
doctests.py                           11     11     0%
mathutils.py                           6      6     0%
pytest_composing_fixture.py           50     50     0%
pytest_coverage.py                     3      0   100%
pytest_db.py                          27     27     0%
pytest_div.py                         48     48     0%
pytest_dividing_2.py                  33     33     0%
pytest_factory_fixture.py             28     28     0%
pytest_factory_fixture_1.py           28     28     0%
pytest_fixture_finalization.py        10     10     0%
pytest_fixture_finalization_2.py      19     19     0%
pytest_fixture_scope_class1.py        14     14     0%
pytest_fixture_scope_class.py         12     12     0%
pytest_fixture_scope_function.py      11     11     0%
pytest_fixture_scope_module.py         9      9     0%
pytest_fixture_scope_session1.py       6      6     0%
pytest_fixture_scope_session2.py       6      6     0%
pytest_fixture_scope_session.py        9      9     0%
pytest_fundamentals.py                11     11     0%
pytest_grouped_tests.py               14     14     0%
pytest_implementing_fixtures.py       15     15     0%
pytest_parametrized_fixtures.py        9      9     0%
pytest_parametrized_tests.py           5      5     0%
pytest_parametrized_tests_1.py         6      6     0%
pytest_shared_fixture.py              15     15     0%
pytest_skipping_tests.py              24     24     0%
pytest_suppressoutput.py              33     33     0%
pytest_tmpdir.py                       5      5     0%
recently_used_list.py                 14     14     0%
sftp.py                               13     13     0%
shape.py                              11     11     0%
suppress_output_module.py              9      9     0%
test.py                               11     11     0%
test_capsys.py                         5      5     0%
test_fixture_scope_session1.py         5      5     0%
test_fixture_scope_session2.py         5      5     0%
test_library_system.py               115    115     0%
test_resource_manager.py              27     27     0%
------------------------------------------------------
TOTAL                                710    701     1%
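The report shows 75% for `abs_util.py` because `test_absolute` only exercises the negative branch, so `return x` never runs (`coverage report -m` lists such missing line numbers). A sketch of a parametrized test that reaches both branches; `absolute` is copied inline so the snippet runs on its own, whereas in the notebook you would keep `from abs_util import absolute`:

```python
import pytest


def absolute(x):
    # Same logic as abs_util.absolute, inlined to keep the sketch self-contained
    if x >= 0:
        return x
    else:
        return -x


@pytest.mark.parametrize(
    "x, expected",
    [
        (-2, 2),  # else branch
        (0, 0),   # boundary value, takes the `x >= 0` branch
        (3, 3),   # if branch
    ],
)
def test_absolute(x, expected):
    assert absolute(x) == expected
```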
! coverage html
Wrote HTML report to htmlcov/index.html
Exercise: 🏠 Coverage of suppress_output¶
Measure code coverage on suppress_output.
pytest-cov Package¶
! pip install pytest-cov
Requirement already satisfied: pytest-cov in ./.venv/lib/python3.12/site-packages (6.0.0)
Requirement already satisfied: pytest>=4.6 in ./.venv/lib/python3.12/site-packages (from pytest-cov) (8.2.2)
Requirement already satisfied: coverage>=7.5 in ./.venv/lib/python3.12/site-packages (from coverage[toml]>=7.5->pytest-cov) (7.6.7)
Requirement already satisfied: iniconfig in ./.venv/lib/python3.12/site-packages (from pytest>=4.6->pytest-cov) (2.0.0)
Requirement already satisfied: packaging in ./.venv/lib/python3.12/site-packages (from pytest>=4.6->pytest-cov) (24.1)
Requirement already satisfied: pluggy<2.0,>=1.5 in ./.venv/lib/python3.12/site-packages (from pytest>=4.6->pytest-cov) (1.5.0)
! pytest --cov=. pytest_coverage.py -q
. [100%]

---------- coverage: platform darwin, python 3.12.6-final-0 ----------
Name                               Stmts   Miss  Cover
------------------------------------------------------
abs_util.py                            4      1    75%
calculator.py                         19     19     0%
conftest.py                            5      2    60%
doctests.py                           11     11     0%
mathutils.py                           6      6     0%
pytest_composing_fixture.py           50     50     0%
pytest_coverage.py                     3      0   100%
pytest_db.py                          27     27     0%
pytest_div.py                         48     48     0%
pytest_dividing_2.py                  33     33     0%
pytest_factory_fixture.py             28     28     0%
pytest_factory_fixture_1.py           28     28     0%
pytest_fixture_finalization.py        10     10     0%
pytest_fixture_finalization_2.py      19     19     0%
pytest_fixture_scope_class1.py        14     14     0%
pytest_fixture_scope_class.py         12     12     0%
pytest_fixture_scope_function.py      11     11     0%
pytest_fixture_scope_module.py         9      9     0%
pytest_fixture_scope_session1.py       6      6     0%
pytest_fixture_scope_session2.py       6      6     0%
pytest_fixture_scope_session.py        9      9     0%
pytest_fundamentals.py                11     11     0%
pytest_grouped_tests.py               14     14     0%
pytest_implementing_fixtures.py       15     15     0%
pytest_parametrized_fixtures.py        9      9     0%
pytest_parametrized_tests.py           5      5     0%
pytest_parametrized_tests_1.py         6      6     0%
pytest_shared_fixture.py              15     15     0%
pytest_skipping_tests.py              24     24     0%
pytest_suppressoutput.py              33     33     0%
pytest_tmpdir.py                       5      5     0%
recently_used_list.py                 14     14     0%
sftp.py                               13     13     0%
shape.py                              11     11     0%
suppress_output_module.py              9      9     0%
test.py                               11     11     0%
test_capsys.py                         5      5     0%
test_fixture_scope_session1.py         5      5     0%
test_fixture_scope_session2.py         5      5     0%
test_library_system.py               115    115     0%
test_resource_manager.py              27     27     0%
------------------------------------------------------
TOTAL                                710    701     1%

1 passed in 0.46s
! pytest --cov=. --cov-report=html pytest_coverage.py
============================= test session starts ==============================
platform darwin -- Python 3.12.6, pytest-8.2.2, pluggy-1.5.0
rootdir: /Users/a563420/python_training/testing
plugins: cov-6.0.0, anyio-4.4.0
collected 1 item

pytest_coverage.py . [100%]

---------- coverage: platform darwin, python 3.12.6-final-0 ----------
Coverage HTML written to dir htmlcov

============================== 1 passed in 1.40s ===============================
! pytest --cov=. --cov-report=term --cov-report=html pytest_coverage.py
============================= test session starts ==============================
platform darwin -- Python 3.12.6, pytest-8.2.2, pluggy-1.5.0
rootdir: /Users/a563420/python_training/testing
plugins: cov-6.0.0, anyio-4.4.0
collected 1 item

pytest_coverage.py . [100%]

---------- coverage: platform darwin, python 3.12.6-final-0 ----------
Name                               Stmts   Miss  Cover
------------------------------------------------------
abs_util.py                            4      1    75%
calculator.py                         19     19     0%
conftest.py                            5      2    60%
doctests.py                           11     11     0%
mathutils.py                           6      6     0%
pytest_composing_fixture.py           50     50     0%
pytest_coverage.py                     3      0   100%
pytest_db.py                          27     27     0%
pytest_div.py                         48     48     0%
pytest_dividing_2.py                  33     33     0%
pytest_factory_fixture.py             28     28     0%
pytest_factory_fixture_1.py           28     28     0%
pytest_fixture_finalization.py        10     10     0%
pytest_fixture_finalization_2.py      19     19     0%
pytest_fixture_scope_class1.py        14     14     0%
pytest_fixture_scope_class.py         12     12     0%
pytest_fixture_scope_function.py      11     11     0%
pytest_fixture_scope_module.py         9      9     0%
pytest_fixture_scope_session1.py       6      6     0%
pytest_fixture_scope_session2.py       6      6     0%
pytest_fixture_scope_session.py        9      9     0%
pytest_fundamentals.py                11     11     0%
pytest_grouped_tests.py               14     14     0%
pytest_implementing_fixtures.py       15     15     0%
pytest_parametrized_fixtures.py        9      9     0%
pytest_parametrized_tests.py           5      5     0%
pytest_parametrized_tests_1.py         6      6     0%
pytest_shared_fixture.py              15     15     0%
pytest_skipping_tests.py              24     24     0%
pytest_suppressoutput.py              33     33     0%
pytest_tmpdir.py                       5      5     0%
recently_used_list.py                 14     14     0%
sftp.py                               13     13     0%
shape.py                              11     11     0%
suppress_output_module.py              9      9     0%
test.py                               11     11     0%
test_capsys.py                         5      5     0%
test_fixture_scope_session1.py         5      5     0%
test_fixture_scope_session2.py         5      5     0%
test_library_system.py               115    115     0%
test_resource_manager.py              27     27     0%
------------------------------------------------------
TOTAL                                710    701     1%
Coverage HTML written to dir htmlcov

============================== 1 passed in 1.76s ===============================
Doctests¶
Doctests are a convenient way to embed tests in the docstrings of your functions, methods, or modules. They serve both as documentation and as a way to ensure that your code behaves as expected. When you run doctests, Python executes the code examples in the docstrings and checks whether the output matches the expected results.
%%writefile doctests.py
import doctest

def factorial(n):
    """Returns the factorial of n (n!).

    >>> factorial(3)
    6
    >>> factorial(30)
    265252859812191058636308480000000
    >>> factorial(-1)
    Traceback (most recent call last):
        ...
    ValueError: n must be >= 0
    >>> [factorial(n)
    ... for n in range(6)]
    [1, 1, 2, 6, 24, 120]
    """
    if not n >= 0:
        raise ValueError("n must be >= 0")
    result = 1
    factor = 2
    while factor <= n:
        result *= factor
        factor += 1
    return result

# if __name__ == "__main__":
#     doctest.testmod()
Overwriting doctests.py
! python -m doctest doctests.py -v
Trying:
    factorial(3)
Expecting:
    6
ok
Trying:
    factorial(30)
Expecting:
    265252859812191058636308480000000
ok
Trying:
    factorial(-1)
Expecting:
    Traceback (most recent call last):
        ...
    ValueError: n must be >= 0
ok
Trying:
    [factorial(n)
    for n in range(6)]
Expecting:
    [1, 1, 2, 6, 24, 120]
ok
1 items had no tests:
    doctests
1 items passed all tests:
   4 tests in doctests.factorial
4 tests in 2 items.
4 passed and 0 failed.
Test passed.
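Doctests compare printed text, so long or wrapped outputs can fail on whitespace alone. Option flags such as `NORMALIZE_WHITESPACE`, enabled per example with a `# doctest:` directive, relax that comparison. A sketch with a hypothetical `mean` function:

```python
import doctest


def mean(values):
    """Return the arithmetic mean of a non-empty sequence.

    >>> mean([1, 2, 3, 4])
    2.5
    >>> mean([])
    Traceback (most recent call last):
        ...
    ValueError: mean() requires at least one value
    >>> [round(mean(range(n)), 1) for n in range(1, 6)]  # doctest: +NORMALIZE_WHITESPACE
    [0.0, 0.5, 1.0,
     1.5, 2.0]
    """
    values = list(values)
    if not values:
        raise ValueError("mean() requires at least one value")
    return sum(values) / len(values)


if __name__ == "__main__":
    # Runs every docstring example in this module
    doctest.testmod()
```

Without the directive, the last example would fail, because the expected output is split over two lines while the real output is on one.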
Exercise: 🏠 Doctest: suppress_output¶
Rewrite tests for @suppress_output in docstring.
Solution¶
%%writefile suppress_output_doctest.py
from contextlib import contextmanager
import doctest
import io
import sys

@contextmanager
def suppress_output():
    """
    Example Usage:
    >>> print('this will be displayed')
    this will be displayed
    >>> with suppress_output():
    ...     print("this won't be displayed")
    >>> before = sys.stdout
    >>> print('printed')
    printed
    >>> with suppress_output() as stream:
    ...     inside = sys.stdout
    ...     print('not printed')
    >>> after = sys.stdout
    >>> print('printed again')
    printed again
    >>> inside is before
    False
    >>> after is before
    True
    >>> stream is before
    True
    >>> before = sys.stdout
    >>> with suppress_output():
    ...     print("not printed")
    ...     raise ZeroDivisionError
    Traceback (most recent call last):
    ZeroDivisionError
    >>> after = sys.stdout
    >>> after is before
    True
    >>> print('still works')
    still works
    """
    stdout = sys.stdout
    sys.stdout = io.StringIO()  # discard everything printed inside the with block
    try:
        yield stdout
    finally:
        sys.stdout = stdout  # restore the original stream, even after an exception

if __name__ == "__main__":
    doctest.testmod()
Writing suppress_output_doctest.py
! python -m doctest suppress_output_doctest.py -v
Trying:
    print('this will be displayed')
Expecting:
    this will be displayed
ok
Trying:
    with suppress_output():
        print("this won't be displayed")
Expecting nothing
ok
Trying:
    before = sys.stdout
Expecting nothing
ok
Trying:
    print('printed')
Expecting:
    printed
ok
Trying:
    with suppress_output() as stream:
        inside = sys.stdout
        print('not printed')
Expecting nothing
ok
Trying:
    after = sys.stdout
Expecting nothing
ok
Trying:
    print('printed again')
Expecting:
    printed again
ok
Trying:
    inside is before
Expecting:
    False
ok
Trying:
    after is before
Expecting:
    True
ok
Trying:
    stream is before
Expecting:
    True
ok
Trying:
    before = sys.stdout
Expecting nothing
ok
Trying:
    with suppress_output():
        print("not printed")
        raise ZeroDivisionError
Expecting:
    Traceback (most recent call last):
    ZeroDivisionError
ok
Trying:
    after = sys.stdout
Expecting nothing
ok
Trying:
    after is before
Expecting:
    True
ok
Trying:
    print('still works')
Expecting:
    still works
ok
1 items had no tests:
    suppress_output_doctest
1 items passed all tests:
  15 tests in suppress_output_doctest.suppress_output
15 tests in 2 items.
15 passed and 0 failed.
Test passed.
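For everyday use, the standard library already ships a context manager with the same behavior as `suppress_output`: `contextlib.redirect_stdout` temporarily replaces `sys.stdout` with a target of your choice and restores it afterwards, even when an exception is raised. A short sketch; `noisy` is a hypothetical function:

```python
import contextlib
import io
import sys


def noisy():
    # Hypothetical function that prints more than we want to see
    print("lots of output")
    return 42


before = sys.stdout
buffer = io.StringIO()
with contextlib.redirect_stdout(buffer):
    result = noisy()
captured = buffer.getvalue()  # the output is kept instead of thrown away

# sys.stdout is restored even if the block raises
try:
    with contextlib.redirect_stdout(io.StringIO()):
        raise ZeroDivisionError
except ZeroDivisionError:
    pass
```

Unlike the hand-written version, the captured text stays available in the buffer, which is handy when a test wants to inspect it.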
unittest¶
The unittest module is Python's built-in testing framework, offering a class-based approach to organizing and running tests. While Pytest is more modern and flexible, unittest serves as a reference for understanding foundational testing principles in Python. It is especially useful for scenarios that benefit from strict class-based organization, or when working within legacy codebases that already rely on this framework.
%%writefile unittest_demo.py
import unittest

class TestCase(unittest.TestCase):
    def setUp(self):
        # Setup logic executed before each test (assumes people.csv exists in the working directory)
        self.stream = open('people.csv')
        print('setUp')

    def tearDown(self):
        # Teardown logic executed after each test
        self.stream.close()
        print('tearDown')

    def test_one(self):
        print('test_one')
        self.assertEqual(2 + 2, 4)  # Basic assertion

    def test_two(self):
        print('test_two')
        self.assertEqual(2 + 2, 4)  # Basic assertion

if __name__ == "__main__":
    unittest.main()
Writing unittest_demo.py
! python unittest_demo.py
setUp
test_one
tearDown
.setUp
test_two
tearDown
.
----------------------------------------------------------------------
Ran 2 tests in 0.001s

OK
! pytest unittest_demo.py -v
============================= test session starts ==============================
platform darwin -- Python 3.12.6, pytest-8.2.2, pluggy-1.5.0 -- /Users/a563420/python_training/testing/.venv/bin/python
cachedir: .pytest_cache
rootdir: /Users/a563420/python_training/testing
plugins: cov-6.0.0, anyio-4.4.0
collected 2 items

unittest_demo.py::TestCase::test_one PASSED [ 50%]
unittest_demo.py::TestCase::test_two PASSED [100%]

============================== 2 passed in 0.10s ===============================
Key Components¶
- setUp:
  - Runs before each test method in the test case.
  - Typically used to initialize resources (e.g., opening files, setting up mock data).
- tearDown:
  - Runs after each test method in the test case.
  - Typically used to clean up resources (e.g., closing files, resetting states).
- Test Methods:
  - Methods starting with test_ are considered test methods.
  - Use assertEqual and other assertion methods to validate test conditions.
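Beyond setUp and tearDown, unittest offers class-level hooks and a way to loop over cases. A sketch reusing the `div` function from the pytest exercises: `setUpClass` runs once per class (similar in spirit to a class-scoped fixture), `subTest` plays a role similar to `@pytest.mark.parametrize`, and `assertRaises` corresponds to `pytest.raises`. The class and case names are our own:

```python
import unittest


def div(a, b):
    return a / b


class DivTests(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        # Runs once for the whole class, not before every test
        cls.cases = [(4, 2, 2.0), (3, 2, 1.5), (-3, -2, 1.5)]

    def test_values(self):
        for a, b, expected in self.cases:
            # Each failing case is reported separately instead of stopping at the first one
            with self.subTest(a=a, b=b):
                self.assertEqual(div(a, b), expected)

    def test_zero_division(self):
        # Counterpart of pytest.raises
        with self.assertRaises(ZeroDivisionError):
            div(1, 0)
```

Saved as, say, `unittest_div.py` (a hypothetical name), it runs with `python -m unittest unittest_div` or with `pytest unittest_div.py`, since pytest also collects `unittest.TestCase` subclasses.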
Summary¶
The unittest framework provides a structured, class-based approach to testing and ships with Python's standard library. While it requires more boilerplate code than Pytest, it suits scenarios where strict class-based organization and explicit setUp/tearDown logic are preferred.
Refactoring in Python¶
Refactoring is the process of restructuring existing code without changing its external behavior. It's a disciplined way to clean up code that minimizes the chances of introducing bugs. Refactoring improves the design of software, makes it easier to understand, and can help find bugs or enhance performance.
In this section, we'll explore several refactoring techniques using Python examples. Each example will illustrate how to improve code readability, maintainability, and scalability. We'll start with a piece of code "Before" refactoring, discuss its limitations, and then present the "After" refactored version along with explanations of the improvements made.
General Steps for Refactoring in Python¶
Understand the Current Code:
- Analyze the existing functionality to fully understand its purpose and behavior.
- Identify any potential redundancies, code smells, or areas for improvement.
Write Comprehensive Tests Before Refactoring:
- Ensure the code is fully covered by unit tests that verify its current behavior.
- Tests should validate both the expected outputs and any side effects.
- Without tests, refactoring may inadvertently introduce bugs.
Identify the Target for Refactoring:
- Locate specific code fragments that can be improved, such as:
  - Long or complex methods.
  - Repeated patterns or code duplication.
  - Code with unclear intent or excessive complexity.
Break Down Refactoring Tasks:
- Plan the steps needed to refactor the identified code without altering its functionality.
- Focus on incremental changes to minimize the risk of errors.
Refactor the Code Incrementally:
- Perform changes step by step, testing after each modification.
- Examples of refactoring techniques include:
  - Extract Method: Break down large methods into smaller, focused methods.
  - Introduce Variables: Replace complex expressions or repeated calculations with named variables.
  - Replace Magic Numbers: Replace hard-coded values with constants or configuration options.
  - Simplify Conditional Logic: Refactor complex conditionals into more readable structures.
Run Tests Frequently:
- After each refactoring step, run the tests to ensure the code still behaves as expected.
- If tests fail, investigate and fix the issue before proceeding.
Optimize and Clean Up:
- Review the refactored code to ensure it adheres to best practices, such as the Single Responsibility Principle or the DRY (Don't Repeat Yourself) principle.
- Remove unused code or redundant comments, and ensure the code is well-documented.
Commit Changes with Context:
- Clearly document the purpose and scope of the refactoring in your version control system.
- This helps future developers (and your future self) understand why the changes were made.
Perform a Code Review:
- Share the refactored code with peers for review.
- Gather feedback to ensure the changes align with team standards and project goals.
Deploy and Monitor:
- Once the refactored code passes all tests and reviews, deploy it to production.
- Monitor its behavior to confirm that no unexpected issues arise.
By following these general steps, you ensure that refactoring improves the code without compromising its functionality or introducing regressions. This process promotes cleaner, more maintainable, and scalable Python codebases.
Extract Method¶
Extract Method is a fundamental refactoring technique where a fragment of code from an existing method or function is moved into a new, separate method. This new method is given a descriptive name that clearly communicates its purpose. The original code segment is then replaced with a call to the new method. This process enhances code clarity, promotes reusability, and aligns with several SOLID principles of object-oriented design.
Purpose of Extract Method¶
- Improve Readability: Breaking down complex methods into smaller, well-named methods makes code easier to read and understand.
- Encourage Reuse: Extracting code into methods allows for reuse in other parts of the codebase, reducing duplication.
- Simplify Maintenance: Smaller methods are easier to test, debug, and maintain. Changes are localized, minimizing the impact on the rest of the system.
- Enhance Abstraction: High-level methods can orchestrate the flow by calling lower-level methods, promoting modularity.
Alignment with SOLID Principles¶
The Extract Method technique closely relates to several SOLID principles:
Single Responsibility Principle (SRP):
- Definition: A class or method should have one, and only one, reason to change.
- Alignment: By extracting methods, you ensure that each method has a single responsibility. This reduces complexity and makes the codebase more robust to changes.
Open/Closed Principle (OCP):
- Definition: Software entities should be open for extension but closed for modification.
- Alignment: Extracting methods allows you to extend functionality by adding new methods rather than modifying existing ones. This minimizes the risk of introducing bugs in tested code.
Liskov Substitution Principle (LSP):
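- Definition: Objects of a subclass should be usable wherever objects of their base class are expected, without breaking the correctness of the program.
- Alignment: While not directly affected by Extract Method, well-defined methods make it easier to create subclasses that override methods without altering the expected behavior, supporting LSP.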
Interface Segregation Principle (ISP):
- Definition: Many client-specific interfaces are better than one general-purpose interface.
- Alignment: By extracting methods, you effectively create smaller, focused interfaces (methods) that clients can use without depending on larger, monolithic methods.
Dependency Inversion Principle (DIP):
- Definition: Depend upon abstractions, not concretions.
- Alignment: Extract Method can help in isolating dependencies and promoting the use of abstractions within methods, thus supporting DIP.
When to Use Extract Method¶
- Long or Complex Methods: Simplify methods that are difficult to read or understand due to their length or complexity.
- Duplicate Code: Eliminate code duplication by extracting common code into a single method.
- Distinct Responsibilities: Separate concerns when a method performs multiple tasks.
- Complex Conditionals or Loops: Extract complex logic into well-named methods to clarify intent.
How to Perform Extract Method¶
- Identify the Code Fragment: Locate a section of code within a method that serves a distinct purpose or represents a logical unit.
- Create a New Method: Move the identified code into a new method, naming it to reflect its functionality.
- Replace the Original Code: In the original method, replace the code fragment with a call to the new method.
- Adjust Parameters and Returns: Ensure the new method has the necessary parameters and returns any required values.
- Test Thoroughly: Run the existing tests to verify that the behavior remains unchanged.
Exercise: 🏠 Calculator¶
Use the refactoring guideline to optimize the code. Focus on using Extract Method.
%%writefile calculator.py
class Calculator:
    def calculate(self, operation, a, b):
        if operation == 'add':
            result = a + b
            print(f"The result of adding {a} and {b} is {result}")
            return result
        elif operation == 'subtract':
            result = a - b
            print(f"The result of subtracting {b} from {a} is {result}")
            return result
        elif operation == 'multiply':
            result = a * b
            print(f"The result of multiplying {a} and {b} is {result}")
            return result
        elif operation == 'divide':
            if b == 0:
                raise ValueError("Cannot divide by zero")
            result = a / b
            print(f"The result of dividing {a} by {b} is {result}")
            return result
        else:
            raise ValueError(f"Unknown operation '{operation}'")
Overwriting calculator.py
%%writefile test_calculator.py
import pytest
from calculator import Calculator

def test_add():
    calc = Calculator()
    assert calc.calculate('add', 2, 3) == 5

def test_subtract():
    calc = Calculator()
    assert calc.calculate('subtract', 5, 2) == 3

def test_multiply():
    calc = Calculator()
    assert calc.calculate('multiply', 3, 4) == 12

def test_divide():
    calc = Calculator()
    assert calc.calculate('divide', 10, 2) == 5

def test_divide_by_zero():
    calc = Calculator()
    with pytest.raises(ValueError, match="Cannot divide by zero"):
        calc.calculate('divide', 10, 0)

def test_unknown_operation():
    calc = Calculator()
    with pytest.raises(ValueError, match="Unknown operation 'modulo'"):
        calc.calculate('modulo', 10, 2)
Writing test_calculator.py
! pytest test_calculator.py
============================= test session starts ==============================
platform darwin -- Python 3.12.6, pytest-8.2.2, pluggy-1.5.0
rootdir: /Users/a563420/python_training/testing
plugins: cov-6.0.0, anyio-4.4.0
collected 6 items

test_calculator.py ...... [100%]

============================== 6 passed in 0.10s ===============================
! coverage run -m pytest test_calculator.py -q
! coverage report
...... [100%]
6 passed in 0.26s
Name                 Stmts   Miss  Cover
----------------------------------------
calculator.py           21      0   100%
conftest.py              5      2    60%
test_calculator.py      22      0   100%
----------------------------------------
TOTAL                   48      2    96%
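The refactoring guideline above asks for tests that validate side effects as well as outputs, but `test_calculator.py` only checks return values: a refactoring that changed the printed messages would still pass. A sketch of extra characterization tests using pytest's built-in `capsys` fixture; the `Calculator` here is a trimmed copy (add and divide only) so the snippet runs on its own, whereas in the notebook you would import it from `calculator`:

```python
import pytest


class Calculator:
    # Trimmed copy of calculator.Calculator (add/divide only), for a self-contained sketch
    def calculate(self, operation, a, b):
        if operation == 'add':
            result = a + b
            print(f"The result of adding {a} and {b} is {result}")
            return result
        elif operation == 'divide':
            if b == 0:
                raise ValueError("Cannot divide by zero")
            result = a / b
            print(f"The result of dividing {a} by {b} is {result}")
            return result
        else:
            raise ValueError(f"Unknown operation '{operation}'")


@pytest.mark.parametrize("operation, a, b, expected_out", [
    ("add", 2, 3, "The result of adding 2 and 3 is 5\n"),
    ("divide", 10, 2, "The result of dividing 10 by 2 is 5.0\n"),
])
def test_calculate_prints_message(capsys, operation, a, b, expected_out):
    Calculator().calculate(operation, a, b)
    # capsys captures everything written to sys.stdout during the test
    assert capsys.readouterr().out == expected_out


def test_divide_by_zero_prints_nothing(capsys):
    with pytest.raises(ValueError):
        Calculator().calculate('divide', 10, 0)
    assert capsys.readouterr().out == ""
```

Pinning down the current output this way, before extracting methods, makes any accidental change to the messages show up as a failing test.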
%%writefile calculator.py
class Calculator:
    def calculate(self, operation, a, b):
        if operation == 'add':
            return self._add(a, b)
        elif operation == 'subtract':
            return self._subtract(a, b)
        elif operation == 'multiply':
            return self._multiply(a, b)
        elif operation == 'divide':
            return self._divide(a, b)
        else:
            raise ValueError(f"Unknown operation '{operation}'")

    def _add(self, a, b):
        result = a + b
        print(f"The result of adding {a} and {b} is {result}")
        return result

    def _subtract(self, a, b):
        result = a - b
        print(f"The result of subtracting {b} from {a} is {result}")
        return result

    def _multiply(self, a, b):
        result = a * b
        print(f"The result of multiplying {a} and {b} is {result}")
        return result

    def _divide(self, a, b):
        if b == 0:
            raise ValueError("Cannot divide by zero")
        result = a / b
        print(f"The result of dividing {a} by {b} is {result}")
        return result
Overwriting calculator.py
! coverage run -m pytest test_calculator.py -q
! coverage report
...... [100%]
6 passed in 0.17s
Name                 Stmts   Miss  Cover
----------------------------------------
calculator.py           29      0   100%
conftest.py              5      2    60%
test_calculator.py      22      0   100%
----------------------------------------
TOTAL                   56      2    96%
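`calculate` still contains an if/elif chain that grows with every new operation. As a possible next step under Simplify Conditional Logic, the chain can be replaced by a lookup table. A sketch using the standard `operator` module; the `_OPERATIONS` table is our own design, and it keeps the same messages and exceptions so the six tests in `test_calculator.py` should still pass:

```python
import operator


class Calculator:
    # Maps each operation name to (function, message template)
    _OPERATIONS = {
        'add': (operator.add, "The result of adding {a} and {b} is {result}"),
        'subtract': (operator.sub, "The result of subtracting {b} from {a} is {result}"),
        'multiply': (operator.mul, "The result of multiplying {a} and {b} is {result}"),
        'divide': (operator.truediv, "The result of dividing {a} by {b} is {result}"),
    }

    def calculate(self, operation, a, b):
        try:
            func, message = self._OPERATIONS[operation]
        except KeyError:
            raise ValueError(f"Unknown operation '{operation}'") from None
        if operation == 'divide' and b == 0:
            raise ValueError("Cannot divide by zero")
        result = func(a, b)
        print(message.format(a=a, b=b, result=result))
        return result
```

Adding an operation now means adding one table entry; the special case for division by zero is the only branch left in `calculate`.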
Example: Registration Form¶
import unittest


class RegistrationFormTests(unittest.TestCase):
    # Assumption: self.app is a test client for the web application (for example
    # a WebTest TestApp) created in setUp(), which is not shown here.
    def test_successful_submission(self):
        form = self.app.get('/register').form
        # Fill in the form
        form['personal'] = 'John'
        form['family'] = 'Smith'
        form['email'] = 'john@smith.com'
        form['location'] = 'Cracow'
        form['country'] = 'PL'
        form['agreed_to_code_of_conduct'] = True
        form['recaptcha_response_field'] = 'PASSED'
        response = form.submit()
        self.assertEqual(response.status_code, 200)
# We've extracted the code responsible for filling in the form into a separate function, fill_in_form().
class RegistrationFormTests(unittest.TestCase):
    def test_successful_submission(self):
        form = self.app.get('/register').form  # a form object with dict-style field access
        fill_in_form(form)
        response = form.submit()
        self.assertEqual(response.status_code, 200)


def fill_in_form(form):
    form['personal'] = 'John'
    form['family'] = 'Smith'
    form['email'] = 'john@smith.com'
    form['location'] = 'Cracow'
    form['country'] = 'PL'
    form['agreed_to_code_of_conduct'] = True
    form['recaptcha_response_field'] = 'PASSED'
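Because fill_in_form is now an ordinary function, it can also accept per-test overrides, so a test for an invalid submission only states what differs from the valid case. The sketch below is hypothetical: DEFAULT_REGISTRATION and the overrides parameter are not part of the original example, and a plain dict stands in for the real form object so the snippet runs on its own.

```python
# Hypothetical defaults for a valid registration (values copied from the example above).
DEFAULT_REGISTRATION = {
    'personal': 'John',
    'family': 'Smith',
    'email': 'john@smith.com',
    'location': 'Cracow',
    'country': 'PL',
    'agreed_to_code_of_conduct': True,
    'recaptcha_response_field': 'PASSED',
}


def fill_in_form(form, **overrides):
    """Fill `form` with valid defaults, then apply any per-test overrides."""
    for field, value in {**DEFAULT_REGISTRATION, **overrides}.items():
        form[field] = value
    return form


# A plain dict stands in for the real form object here.
valid_form = fill_in_form({})
missing_email_form = fill_in_form({}, email='')
print(valid_form['email'], repr(missing_email_form['email']))
```

A test for a missing e-mail address would then call fill_in_form(form, email='') and assert on the error response, while the valid-submission test keeps calling fill_in_form(form).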
Variables¶
Refactoring with the "Variables" technique involves simplifying complex expressions by introducing meaningful variables to represent intermediate values or logical components. This approach enhances readability, reduces redundancy, and makes the code easier to maintain and debug. It also helps clarify the intent of calculations and logic by breaking them into smaller, understandable parts.
Steps for Refactoring with Variables¶
Identify Complex Expressions:
- Locate sections of the code with:
- Nested or chained expressions.
- Repeated calculations.
- Hard-to-read logic.
- These are candidates for breaking into smaller, meaningful components.
Assign Intermediate Values to Variables:
- Extract repeated or complex expressions into descriptive variables.
- Ensure the variable names clearly convey the meaning of the value they represent.
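Example: Replace:
total = price * (1 + tax_rate) - discount if price > threshold else price
With:
tax = price * tax_rate
discounted_price = price - discount if price > threshold else price
total = discounted_price + tax
Note: the first expression hides an operator-precedence issue. A conditional expression binds more loosely than -, so it reads as (price * (1 + tax_rate) - discount) if price > threshold else price: at or below the threshold, no tax is applied at all. The variable-based version applies tax in both branches, so it is not a behavior-preserving refactoring of the original. Spelling out the intermediate values is exactly what makes this kind of discrepancy visible.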
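A quick side-by-side run makes the precedence point concrete. The two functions below wrap the expressions from the example; the sample values (tax rate 0.25, discount 10, threshold 100) are arbitrary and chosen only so that the float arithmetic stays exact.

```python
def total_original(price, tax_rate, discount, threshold):
    # The conditional expression binds loosest, so this is
    # (price * (1 + tax_rate) - discount) if price > threshold else price
    return price * (1 + tax_rate) - discount if price > threshold else price


def total_refactored(price, tax_rate, discount, threshold):
    tax = price * tax_rate
    discounted_price = price - discount if price > threshold else price
    return discounted_price + tax


# Illustrative values only: one price below the threshold, one above it.
for price in (50, 200):
    print(price, total_original(price, 0.25, 10, 100), total_refactored(price, 0.25, 10, 100))
```

Above the threshold both versions agree (240.0); below it the original returns the bare price (50) while the refactored version adds tax (62.5). Before refactoring such an expression, decide which behavior is intended and pin it down with a test.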
Replace Magic Numbers with Constants:
- Replace hard-coded numbers with named constants to clarify their purpose.
- This improves readability and makes future adjustments easier.
Example:
Replace:
if price > 100:
    discount = price * 0.1
With:
DISCOUNT_THRESHOLD = 100
DISCOUNT_RATE = 0.1
if price > DISCOUNT_THRESHOLD:
    discount = price * DISCOUNT_RATE
Simplify Conditional Logic:
- Break down complex conditions into variables or helper functions with descriptive names.
Example:
Replace:
if user.is_active and (user.age > 18 or user.has_permission):
    return "Allowed"
With:
is_adult = user.age > 18
is_allowed = user.is_active and (is_adult or user.has_permission)
if is_allowed:
    return "Allowed"
Group Related Calculations:
- Combine calculations or logical steps into sequential, meaningful variables to highlight their relationships.
Example:
Replace:
net_income = revenue - (expenses + taxes)
With:
total_costs = expenses + taxes
net_income = revenue - total_costs
Validate Functionality with Tests:
- Ensure the behavior remains consistent after refactoring by writing or running unit tests.
- Cover all possible scenarios, including edge cases, to ensure the refactoring did not alter the logic.
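As a minimal sketch of this step, the test below checks the magic-number example from above before and after the refactoring. Two assumptions are made to keep it runnable: both versions are wrapped in hypothetical functions (discount_before, discount_after), and discount defaults to 0 when the condition is false, which the original snippet leaves undefined.

```python
DISCOUNT_THRESHOLD = 100
DISCOUNT_RATE = 0.1


def discount_before(price):
    # Original version with magic numbers (discount = 0 default is an assumption).
    discount = 0
    if price > 100:
        discount = price * 0.1
    return discount


def discount_after(price):
    # Refactored version with named constants.
    discount = 0
    if price > DISCOUNT_THRESHOLD:
        discount = price * DISCOUNT_RATE
    return discount


def test_discount_refactoring_preserves_behaviour():
    # Include the boundary (100) and the values right around it: that is where
    # a refactoring slip such as `>` turning into `>=` would show up.
    for price in (0, 99, 100, 101, 250):
        assert discount_after(price) == discount_before(price)


test_discount_refactoring_preserves_behaviour()
```

Under pytest the final call is unnecessary; the test function is collected automatically.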
Exercise: 🏠 Calculator 2¶
Use the refactoring guidelines to improve the code below, focusing on the Variables technique.
%%writefile calculator.py
def calculate_subtotal():
    print('Calculate subtotal')
    return 400.0


def calculate_taxable_subtotal():
    print('Calculate taxable subtotal')
    return 300.0


def calculate_total():
    print('Calculate total')
    return (calculate_subtotal() + calculate_taxable_subtotal() * 0.15
            - (calculate_subtotal() * 0.1 if calculate_subtotal() > 100 else 0))

# if __name__ == "__main__":
#     total = calculate_total()
#     print(f'Total: {total}')
Overwriting calculator.py
%%writefile test_calculator.py
# test_calculator.py
import pytest
from calculator import (
    calculate_subtotal,
    calculate_taxable_subtotal,
    calculate_total,
)


def test_calculate_subtotal():
    subtotal = calculate_subtotal()
    assert subtotal == 400.0


def test_calculate_taxable_subtotal():
    taxable_subtotal = calculate_taxable_subtotal()
    assert taxable_subtotal == 300.0


def test_calculate_total():
    total = calculate_total()
    expected_subtotal = 400.0
    expected_taxable_subtotal = 300.0
    expected_discount = expected_subtotal * 0.1 if expected_subtotal > 100 else 0
    expected_tax = expected_taxable_subtotal * 0.15
    expected_total = expected_subtotal + expected_tax - expected_discount
    assert total == expected_total
Overwriting test_calculator.py
! coverage run -m pytest test_calculator.py -q
! coverage report
... [100%]
3 passed in 0.06s
Name                 Stmts   Miss  Cover
----------------------------------------
calculator.py            9      0   100%
conftest.py              5      2    60%
test_calculator.py      16      0   100%
----------------------------------------
TOTAL                   30      2    93%
%%writefile calculator.py
"""
We've introduced variables to store intermediate results:
- **Variable Assignment**: Store the result of `calculate_subtotal()` in `subtotal`.
- **Simplified Logic**: Use an `if` statement to determine the discount.
- **Named Variables**: `discount` and `tax` make the code more understandable.
"""
def calculate_subtotal():
print('Calculate subtotal')
return 400.0
def calculate_taxable_subtotal():
print('Calculate taxable subtotal')
return 300.0
# Refactored function
def calculate_total():
print('Calculate total')
subtotal = calculate_subtotal()
if subtotal > 100:
discount = subtotal * 0.1
else:
discount = 0
tax = calculate_taxable_subtotal() * 0.15
return subtotal + tax - discount
if __name__ == "__main__":
total = calculate_total()
print(f'Total: {total}')
Overwriting calculator.py
! coverage run -m pytest test_calculator.py -q
! coverage report
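... [100%]
3 passed in 0.06s
Name                 Stmts   Miss  Cover
----------------------------------------
calculator.py           17      3    82%
conftest.py              5      2    60%
test_calculator.py      16      0   100%
----------------------------------------
TOTAL                   38      5    87%
The tests still pass, but coverage of calculator.py drops from 100% to 82%. The three missed statements are discount = 0 in the else branch (never reached, because the stubbed subtotal is always 400.0) and the two lines under if __name__ == "__main__":, which only run when the file is executed as a script.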
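Besides readability, storing the result of calculate_subtotal() in a variable changes how often it runs: the original expression calls it three times, the refactored version once. The sketch below makes that visible with a call counter; the counter and the names total_before/total_after are added only for this demonstration, and the print calls are dropped to keep the output short.

```python
call_count = {'subtotal': 0}


def calculate_subtotal():
    call_count['subtotal'] += 1  # instrumentation for this demo only
    return 400.0


def calculate_taxable_subtotal():
    return 300.0


def total_before():
    # Original expression: calculate_subtotal() is evaluated three times.
    return (calculate_subtotal() + calculate_taxable_subtotal() * 0.15
            - (calculate_subtotal() * 0.1 if calculate_subtotal() > 100 else 0))


def total_after():
    # Refactored: the subtotal is computed once and reused.
    subtotal = calculate_subtotal()
    discount = subtotal * 0.1 if subtotal > 100 else 0
    tax = calculate_taxable_subtotal() * 0.15
    return subtotal + tax - discount


for fn in (total_before, total_after):
    call_count['subtotal'] = 0
    result = fn()
    print(fn.__name__, result, call_count['subtotal'])
```

Both versions return 405.0; only the number of calls differs. With a stub that returns a constant this is harmless, but if calculate_subtotal() were expensive or had side effects, the extra calls would matter.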
Exercise: 🏠 Recently Used List¶
%%writefile recently_used_list.py
class RecentlyUsedList(list):
    def append(self, elem):
        if elem in self:
            self.remove(elem)
        super().append(elem)

# if __name__ == "__main__":
#     rul = RecentlyUsedList()
#     rul.append('first')
#     rul.append('second')
#     rul.append('third')
#     rul.append('second')
#     print(rul)
#     rul.extend(['a', 'b', 'first'])  # (!)
#     print(rul)
Overwriting recently_used_list.py
Explanation
The RecentlyUsedList class is intended to keep track of recently used items, moving an existing item to the end when it is added again. However, inheriting from list introduces unintended behavior:
- Unintended Methods: methods like extend() and insert() are available, but they bypass the deduplication in append(), so the list can end up containing duplicates.
- Violation of the Liskov Substitution Principle: the subclass does not behave consistently with its superclass.
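The short run below reproduces the commented-out demo from the exercise and shows the leak: list.extend() and list.insert() add items directly, without going through the overridden append(), so "first" and "third" each end up in the list twice.

```python
class RecentlyUsedList(list):
    def append(self, elem):
        if elem in self:
            self.remove(elem)
        super().append(elem)


rul = RecentlyUsedList()
for item in ('first', 'second', 'third', 'second'):
    rul.append(item)
print(rul)  # ['first', 'third', 'second']

# Inherited methods do not go through the overridden append():
rul.extend(['a', 'b', 'first'])
rul.insert(0, 'third')
print(rul)  # ['third', 'first', 'third', 'second', 'a', 'b', 'first']
```

Overriding every mutating method of list (extend, insert, __setitem__, __iadd__, ...) would be possible but fragile, which is why the exercise switches to composition.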
%%writefile test_recently_used_list.py
import pytest
from recently_used_list import RecentlyUsedList


def test_initialization():
    rul = RecentlyUsedList()
    assert str(rul) == "[]"


def test_append_new_element():
    rul = RecentlyUsedList()
    rul.append('a')
    assert str(rul) == "['a']"


def test_append_existing_element():
    rul = RecentlyUsedList()
    rul.append('a')
    rul.append('b')
    rul.append('a')
    assert str(rul) == "['b', 'a']"


def test_get_item():
    rul = RecentlyUsedList()
    rul.append('a')
    rul.append('b')
    assert rul[0] == 'a'
    assert rul[1] == 'b'


def test_get_item_out_of_range():
    rul = RecentlyUsedList()
    rul.append('a')
    with pytest.raises(IndexError):
        _ = rul[1]
Overwriting test_recently_used_list.py
! coverage run -m pytest test_recently_used_list.py -q
! coverage report
..... [100%]
5 passed in 0.13s
Name                         Stmts   Miss  Cover
------------------------------------------------
conftest.py                      5      2    60%
recently_used_list.py            5      0   100%
test_recently_used_list.py      26      0   100%
------------------------------------------------
TOTAL                           36      2    94%
%%writefile recently_used_list.py
class RecentlyUsedList:
    def __init__(self):
        self._list = []

    def append(self, elem):
        if elem in self._list:
            self._list.remove(elem)
        self._list.append(elem)

    def __str__(self):
        return str(self._list)

    def __getitem__(self, index):
        return self._list[index]
Overwriting recently_used_list.py
! coverage run -m pytest test_recently_used_list.py -q
! coverage report
..... [100%]
5 passed in 0.09s
Name                         Stmts   Miss  Cover
------------------------------------------------
conftest.py                      5      2    60%
recently_used_list.py           11      0   100%
test_recently_used_list.py      26      0   100%
------------------------------------------------
TOTAL                           42      2    95%
Explanation
By switching to composition:
- Controlled Interface: Only the methods we define are available.
- Encapsulation: Internal implementation details are hidden.
- Flexibility: We can change how we store items without affecting external code.
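To illustrate the Flexibility point, here is a sketch of the same public interface backed by a dict instead of a list. This storage choice is an alternative, not part of the original exercise, and it assumes items are hashable (the list-backed version did not need that). Dicts preserve insertion order (guaranteed since Python 3.7), so the keys double as the ordered sequence of items.

```python
class RecentlyUsedList:
    """Same public interface as the list-backed version.

    Assumption: items are hashable, because they are used as dict keys.
    """

    def __init__(self):
        # Only the keys matter; the values are unused placeholders.
        self._items = {}

    def append(self, elem):
        self._items.pop(elem, None)  # forget the old position, if any
        self._items[elem] = None     # re-insert at the end

    def __str__(self):
        return str(list(self._items))

    def __getitem__(self, index):
        # Raises IndexError for out-of-range indexes, like the list version.
        return list(self._items)[index]


rul = RecentlyUsedList()
for item in ('a', 'b', 'a'):
    rul.append(item)
print(rul)  # ['b', 'a']
```

The existing tests in test_recently_used_list.py exercise only str(), append() and indexing, so they should pass against this version unchanged. The trade-off: membership checks and removal become constant-time on average, while indexing now builds a list on every call.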
Exercise: 🏠 Type Code¶
# %%writefile shape.py
class Shape:
    def __init__(self, shape):
        assert shape in ('circle', 'dot')
        self.shape = shape

    def draw(self):
        if self.shape == 'circle':
            print('o')
        elif self.shape == 'dot':
            print('.')


circle = Shape('circle')
circle.draw()
o
Explanation
The Shape class uses a type code (shape) to decide which shape to draw, relying on conditional logic.
Issues
- Limited Extensibility: adding a new shape requires modifying the Shape class.
- Violation of the Open/Closed Principle: the class is not closed for modification.
- Conditional Dispatch: behavior is selected with if/elif branches that must be extended for every new shape.
%%writefile test_shape.py
import pytest
from contextlib import redirect_stdout
import io
import sys
from shape import Shape


def _test_output(shape, expected_output):
    f = io.StringIO()
    with redirect_stdout(f):
        shape.draw()
    output = f.getvalue()
    assert output == expected_output + '\n'


def test_circle():
    _test_output(shape=Shape('circle'), expected_output='o')


def test_dot():
    _test_output(shape=Shape('dot'), expected_output='.')
Overwriting test_shape.py
! coverage run -m pytest test_shape.py -q
! coverage report
.. [100%]
2 passed in 0.09s
Name            Stmts   Miss  Cover
-----------------------------------
conftest.py         5      2    60%
shape.py            9      0   100%
test_shape.py      15      0   100%
-----------------------------------
TOTAL              29      2    93%
%%writefile shape.py
class Circle:
    def draw(self):
        print('o')


class Dot:
    def draw(self):
        print('.')


# Shape factory function
shapes = {
    'circle': Circle,
    'dot': Dot,
}


def Shape(shape):
    factory = shapes[shape]
    return factory()
Overwriting shape.py
! coverage run -m pytest test_shape.py -q
! coverage report
.. [100%]
2 passed in 0.11s
Name            Stmts   Miss  Cover
-----------------------------------
conftest.py         5      2    60%
shape.py           10      0   100%
test_shape.py      15      0   100%
-----------------------------------
TOTAL              30      2    93%
%%writefile shape.py
class Shape:
    def __new__(cls, shape):
        factory = shapes[shape]
        return super().__new__(factory)


class Circle(Shape):
    def draw(self):
        print('o')


class Dot(Shape):
    def draw(self):
        print('.')


# Shapes dictionary
shapes = {
    'circle': Circle,
    'dot': Dot,
}
Overwriting shape.py
! coverage run -m pytest test_shape.py -q
! coverage report
.. [100%]
2 passed in 0.12s
Name            Stmts   Miss  Cover
-----------------------------------
conftest.py         5      2    60%
shape.py           11      0   100%
test_shape.py      15      0   100%
-----------------------------------
TOTAL              31      2    94%
%%writefile shape.py
class Shape:
    _shapes = {}

    def __new__(cls, shape):
        cls._register_subclasses()
        if shape in cls._shapes:
            factory = cls._shapes[shape]
            return super().__new__(factory)
        raise ValueError(f"Shape '{shape}' is not registered.")

    @classmethod
    def _register_subclasses(cls):
        for subclass in cls.__subclasses__():
            shape_name = subclass.__name__.lower()
            cls._shapes[shape_name] = subclass


# Bad usage: building the mapping eagerly at this point, e.g.
#     shapes = {subclass.__name__.lower(): subclass for subclass in Shape.__subclasses__()}
# would produce an empty dict, because Circle and Dot are not defined yet.


class Circle(Shape):
    def draw(self):
        print('o')


class Dot(Shape):
    def draw(self):
        print('.')
Overwriting shape.py
! coverage run -m pytest test_shape.py -q
! coverage report
.. [100%]
2 passed in 0.14s
Name            Stmts   Miss  Cover
-----------------------------------
conftest.py         5      2    60%
shape.py           19      1    95%
test_shape.py      15      0   100%
-----------------------------------
TOTAL              39      3    92%
Explanation
- Overriding
__new__
: TheShape
class overrides the__new__
method to instantiate the correct subclass. - Inheritance Hierarchy:
Circle
andDot
inherit fromShape
, allowingisinstance
checks. - Unified Interface: All shapes are instances of
Shape
, maintaining consistency.
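An alternative to scanning __subclasses__() on every instantiation is the __init_subclass__ hook (Python 3.6+), which Python calls once each time a subclass of Shape is defined. The sketch below swaps in that technique; it is not the approach used in the exercise above. Unlike __subclasses__(), which lists only direct subclasses, the hook also fires for subclasses of subclasses.

```python
import io
from contextlib import redirect_stdout


class Shape:
    _registry = {}

    def __init_subclass__(cls, **kwargs):
        # Runs once, when each subclass is defined, so no scanning is needed later.
        super().__init_subclass__(**kwargs)
        Shape._registry[cls.__name__.lower()] = cls

    def __new__(cls, shape):
        try:
            subclass = Shape._registry[shape]
        except KeyError:
            raise ValueError(f"Shape '{shape}' is not registered.") from None
        return super().__new__(subclass)


class Circle(Shape):
    def draw(self):
        print('o')


class Dot(Shape):
    def draw(self):
        print('.')


buffer = io.StringIO()
with redirect_stdout(buffer):
    Shape('circle').draw()
    Shape('dot').draw()
print(repr(buffer.getvalue()))  # 'o\n.\n'
```

Adding a new shape still only requires defining a new subclass; test_shape.py should keep passing because Shape('circle') and Shape('dot') behave as before.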
Conclusion¶
Refactoring is an essential practice for improving code quality, readability, and maintainability. By applying refactoring techniques, developers can transform complex or inefficient code into cleaner, more efficient, and more understandable versions.
In this training material, we've covered several key refactoring techniques:
- Extract Method: Simplifies code by moving repeated or complex code into separate methods.
- Variable Extraction: Improves readability and efficiency by storing intermediate results in variables.
- Composition Over Inheritance: Enhances flexibility and reduces unintended behaviors by using composition instead of inheritance.
- Replacing Type Code with Classes: Utilizes polymorphism to eliminate type codes and conditional logic, making the code more extensible and maintainable.
Understanding and applying these techniques allows developers to write better Python code that is easier to understand, test, and maintain. Refactoring should be a continuous part of the development process, helping teams deliver high-quality software.
Key Takeaways for Trainers¶
- Highlight the Importance of Refactoring: Emphasize how refactoring leads to better code quality without changing external behavior.
- Use Practical Examples: The provided before-and-after code snippets serve as concrete examples of how refactoring improves code.
- Encourage Best Practices: Stress the importance of readability, maintainability, and adherence to design principles like the Open/Closed Principle.
- Demonstrate Incremental Changes: Show how small, incremental refactorings can lead to significant improvements.
- Promote Continuous Refactoring: Encourage developers to make refactoring a regular part of their workflow.
By focusing on these points, trainers can effectively convey the value of refactoring and equip learners with practical skills to improve their codebases.
Test-Driven Development (TDD)¶
Introduction to Test-Driven Development (TDD)¶
Test-Driven Development (TDD) is a software development process in which a test is written first, followed by the bare minimum of code required to make it pass. It emphasizes working in short cycles of small, incremental tests and code.
Why Use TDD?¶
- Improves Code Quality: By writing tests first, developers think about the design and requirements before implementation.
- Simplifies Debugging: When a test fails, it's easy to locate the issue since it likely lies in the code added since the last passing test.
- Documentation: Tests serve as documentation of the code's intended behavior.
- Refactoring Support: Since tests are in place, code can be refactored with confidence, ensuring existing functionality remains intact.
TDD Theory¶
The TDD Cycle¶
The TDD process can be broken down into the following steps:
- Write the Simplest Failing Test: Start by writing a test that fails because the functionality isn't implemented yet.
- Write the Simplest Code to Pass the Test: Write just enough code to make the failing test pass.
- Refactor Code and Tests: Improve the code's structure and readability without changing its external behavior, ensuring all tests still pass.
Writing the Simplest Failing Test¶
- Purpose: To define a new functionality or requirement.
- Approach:
- Write a test for a specific behavior you want to implement.
- Ensure the test fails initially to verify it's testing the right thing.
Writing the Simplest Code to Pass the Test¶
- Purpose: To implement just enough code to make the failing test pass.
- Approach:
- Focus on minimal implementation.
- Avoid adding any additional functionality beyond what's necessary for the test to pass.
Refactoring Code and Tests¶
- Purpose: To improve the codebase without altering its external behavior.
- Approach:
- Clean up code, eliminate duplication, and improve design.
- Refactor tests if necessary to improve clarity and maintainability.
- Ensure all tests continue to pass after refactoring.
TDD in Practice: The Bowling Game Kata (Step-by-Step Guide)¶
Overview¶
The Bowling Game Kata is a programming exercise designed to practice TDD. We'll implement a BowlingGame class that calculates the score of a bowling game from the rolls provided.
Bowling Scoring Rules¶
- A game consists of 10 frames.
- Each frame can be one of the following:
- Normal (open) frame: The player has two attempts to knock down 10 pins but leaves at least one standing. The score for that frame is the number of pins knocked down.
- Spare: The player knocks down all 10 pins in two attempts. The score for that frame is 10 plus the number of pins knocked down in the next roll.
- Strike: The player knocks down all 10 pins on the first attempt. The score for that frame is 10 plus the number of pins knocked down in the next two rolls.
- In the 10th frame, a player who scores a spare gets one extra roll, and a player who scores a strike gets two extra rolls.
Step 1: Setting Up the Project¶
Before we start, ensure you have the following setup:
- Python 3.x installed.
- pytest installed (pip install pytest).
We'll create two files:
- bowling_game.py: contains the BowlingGame class.
- test_bowling_game.py: contains the tests for the BowlingGame class.
Step 2: Write the First Failing Test¶
Test: Gutter Game (All Zeros)¶
Objective: Ensure that when the player rolls all zeros, the total score is zero.
Writing the Test¶
Create test_bowling_game.py and add the following code:
# test_bowling_game.py
import pytest
from bowling_game import BowlingGame


@pytest.fixture
def game():
    return BowlingGame()


def test_gutter_game(game):
    for _ in range(20):
        game.roll(0)
    assert game.score() == 0
Explanation:
- We use a fixture, game, to create a new instance of BowlingGame before each test.
- In test_gutter_game, we simulate 20 rolls of 0 pins.
- We assert that the final score should be 0.
Running the Test¶
Run the test with:
pytest test_bowling_game.py
Expected Result: The test should fail because BowlingGame is not implemented yet.
!pytest katta/test_bowling_game.py
============================= test session starts ==============================
platform darwin -- Python 3.12.6, pytest-8.2.2, pluggy-1.5.0
rootdir: /Users/a563420/python_training/testing
plugins: cov-6.0.0, anyio-4.4.0
collected 0 items / 1 error

==================================== ERRORS ====================================
_________________ ERROR collecting katta/test_bowling_game.py __________________
ImportError while importing test module '/Users/a563420/python_training/testing/katta/test_bowling_game.py'.
Hint: make sure your test modules/packages have valid Python names.
Traceback:
../../.pyenv/versions/3.12.6/lib/python3.12/importlib/__init__.py:90: in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
katta/test_bowling_game.py:2: in <module>
    from bowling_game import BowlingGame
E   ImportError: cannot import name 'BowlingGame' from 'bowling_game' (/Users/a563420/python_training/testing/katta/bowling_game.py)
=========================== short test summary info ============================
ERROR katta/test_bowling_game.py
!!!!!!!!!!!!!!!!!!!! Interrupted: 1 error during collection !!!!!!!!!!!!!!!!!!!!
=============================== 1 error in 0.54s ===============================
Step 3: Implement Minimal Code to Pass the Test¶
Implementing BowlingGame¶
Create bowling_game.py with the minimal implementation:
class BowlingGame:
    def roll(self, pins):
        pass

    def score(self):
        return 0
Explanation:
- We define the BowlingGame class with two methods:
- roll(pins): Records a roll. Currently, it does nothing.
- score(): Returns the total score. Currently, it always returns 0.
Running the Test Again¶
Run:
pytest test_bowling_game.py
Expected Result: The test should now pass.
!pytest katta/test_bowling_game.py -sv
============================= test session starts ==============================
platform darwin -- Python 3.12.6, pytest-8.2.2, pluggy-1.5.0 -- /Users/a563420/python_training/testing/.venv/bin/python
cachedir: .pytest_cache
rootdir: /Users/a563420/python_training/testing
plugins: cov-6.0.0, anyio-4.4.0
collected 1 item

katta/test_bowling_game.py::test_gitter_gane PASSED

============================== 1 passed in 0.08s ===============================
Step 4: Write the Next Failing Test¶
Test: All Ones¶
Objective: Ensure that when the player rolls all ones, the total score is 20.
Writing the Test¶
Add the following test to test_bowling_game.py:
def test_all_ones(game):
    for _ in range(20):
        game.roll(1)
    assert game.score() == 20
Explanation:
- We simulate 20 rolls of 1 pin each.
- We assert that the final score should be 20.
Running the Test¶
Run:
pytest test_bowling_game.py
Expected Result: The test should fail because score() always returns 0.
!pytest katta/test_bowling_game.py -sv
============================= test session starts ==============================
platform darwin -- Python 3.12.6, pytest-8.2.2, pluggy-1.5.0 -- /Users/a563420/python_training/testing/.venv/bin/python
cachedir: .pytest_cache
rootdir: /Users/a563420/python_training/testing
plugins: cov-6.0.0, anyio-4.4.0
collected 2 items

katta/test_bowling_game.py::test_gitter_gane PASSED
katta/test_bowling_game.py::test_all_ones FAILED

=================================== FAILURES ===================================
________________________________ test_all_ones _________________________________

game = <bowling_game.BowlingGame object at 0x10be6e390>

    def test_all_ones(game):
        for _ in range(20):
            game.roll(1)
>       assert game.score() == 20
E       assert 0 == 20
E        +  where 0 = <bound method BowlingGame.score of <bowling_game.BowlingGame object at 0x10be6e390>>()
E        +    where <bound method BowlingGame.score of <bowling_game.BowlingGame object at 0x10be6e390>> = <bowling_game.BowlingGame object at 0x10be6e390>.score

katta/test_bowling_game.py:17: AssertionError
=========================== short test summary info ============================
FAILED katta/test_bowling_game.py::test_all_ones - assert 0 == 20
========================= 1 failed, 1 passed in 0.44s ==========================
Step 5: Update the Code to Pass the Test¶
Modifying BowlingGame¶
Update bowling_game.py:
class BowlingGame:
    def __init__(self):
        self.rolls = []

    def roll(self, pins):
        self.rolls.append(pins)

    def score(self):
        return sum(self.rolls)
Explanation:
- In __init__, we initialize an empty list, rolls, to keep track of all the rolls.
- In roll(pins), we append each roll to the rolls list.
- In score(), we return the sum of all rolls.
Running the Test Again¶
Run:
pytest test_bowling_game.py
Expected Result: Both tests (test_gutter_game and test_all_ones) should now pass.
!pytest katta/test_bowling_game.py -sv
============================= test session starts ==============================
platform darwin -- Python 3.12.6, pytest-8.2.2, pluggy-1.5.0 -- /Users/a563420/python_training/testing/.venv/bin/python
cachedir: .pytest_cache
rootdir: /Users/a563420/python_training/testing
plugins: cov-6.0.0, anyio-4.4.0
collected 2 items

katta/test_bowling_game.py::test_gitter_gane PASSED
katta/test_bowling_game.py::test_all_ones PASSED

============================== 2 passed in 0.09s ===============================
Step 6: Write a Test for Spares¶
Test: One Spare¶
Objective: Ensure that a spare is scored correctly.
- When a spare is rolled, the frame score is 10 plus the number of pins knocked down in the next roll.
Writing the Test¶
Add the following test to test_bowling_game.py:
def test_one_spare(game):
    game.roll(5)
    game.roll(5)  # spare
    for _ in range(18):
        game.roll(1)
    assert game.score() == 29
Explanation:
- Rolls: 5, 5 (spare), then eighteen 1s.
- The spare bonus is the next roll, which is 1.
- First frame score: 10 + 1 = 11.
- The remaining nine frames score 1 + 1 = 2 each, 18 in total.
- Total score: 11 + 18 = 29.
Running the Test¶
Run:
pytest test_bowling_game.py
Expected Result: The test should fail because the current implementation doesn't handle spares.
!pytest katta/test_bowling_game.py -sv
============================= test session starts ==============================
platform darwin -- Python 3.12.6, pytest-8.2.2, pluggy-1.5.0 -- /Users/a563420/python_training/testing/.venv/bin/python
cachedir: .pytest_cache
rootdir: /Users/a563420/python_training/testing
plugins: cov-6.0.0, anyio-4.4.0
collected 3 items

katta/test_bowling_game.py::test_gitter_gane PASSED
katta/test_bowling_game.py::test_all_ones PASSED
katta/test_bowling_game.py::test_one_spare FAILED

=================================== FAILURES ===================================
________________________________ test_one_spare ________________________________

game = <bowling_game.BowlingGame object at 0x110ed8b90>

    def test_one_spare(game):
        game.roll(5)
        game.roll(5)
        for _ in range(18):
            game.roll(1)
>       assert game.score() == 29
E       assert 28 == 29
E        +  where 28 = <bound method BowlingGame.score of <bowling_game.BowlingGame object at 0x110ed8b90>>()
E        +    where <bound method BowlingGame.score of <bowling_game.BowlingGame object at 0x110ed8b90>> = <bowling_game.BowlingGame object at 0x110ed8b90>.score

katta/test_bowling_game.py:27: AssertionError
=========================== short test summary info ============================
FAILED katta/test_bowling_game.py::test_one_spare - assert 28 == 29
========================= 1 failed, 2 passed in 0.33s ==========================
Step 7: Update Code to Handle Spares¶
Modifying BowlingGame¶
Update bowling_game.py:
class BowlingGame:
    def __init__(self):
        self.rolls = []

    def roll(self, pins):
        self.rolls.append(pins)

    def score(self):
        total = 0
        roll_index = 0
        for frame in range(10):
            if self.is_spare(roll_index):
                total += 10 + self.rolls[roll_index + 2]
                roll_index += 2
            else:
                total += self.rolls[roll_index] + self.rolls[roll_index + 1]
                roll_index += 2
        return total

    def is_spare(self, roll_index):
        return self.rolls[roll_index] + self.rolls[roll_index + 1] == 10
Explanation:
- We iterate over 10 frames.
- is_spare(roll_index): checks whether the two rolls of the frame add up to 10.
- If it's a spare, we add 10 plus the spare bonus (the next roll).
- Otherwise, we sum the two rolls.
- In both cases we advance roll_index by 2, to the first roll of the next frame.
Running the Test Again¶
Run:
pytest test_bowling_game.py
Expected Result: The spare test should pass, along with the previous tests.
!pytest katta/test_bowling_game.py -sv
============================= test session starts ==============================
platform darwin -- Python 3.12.6, pytest-8.2.2, pluggy-1.5.0 -- /Users/a563420/python_training/testing/.venv/bin/python
cachedir: .pytest_cache
rootdir: /Users/a563420/python_training/testing
plugins: cov-6.0.0, anyio-4.4.0
collected 3 items

katta/test_bowling_game.py::test_gitter_gane PASSED
katta/test_bowling_game.py::test_all_ones PASSED
katta/test_bowling_game.py::test_one_spare PASSED

============================== 3 passed in 0.09s ===============================
Step 8: Write a Test for Strikes¶
Test: One Strike¶
Objective: Ensure that a strike is scored correctly.
- When a strike is rolled, the frame score is 10 plus the number of pins knocked down in the next two rolls.
Writing the Test¶
Add the following test to test_bowling_game.py:
def test_one_strike(game):
    game.roll(10)  # strike
    game.roll(3)
    game.roll(2)
    for _ in range(16):
        game.roll(1)
    assert game.score() == 36
Explanation:
- Rolls: 10 (strike), then 3, 2, then sixteen 1s.
- The strike bonus is the next two rolls: 3 + 2 = 5.
- First frame score: 10 + 3 + 2 = 15.
- Second frame: 3 + 2 = 5.
- The remaining eight frames score 2 each, 16 in total.
- Total score: 15 + 5 + 16 = 36.
Running the Test¶
Run:
pytest test_bowling_game.py
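Expected Result: The test should fail because the current implementation doesn't handle strikes. Here it fails with an IndexError rather than a wrong score: the strike is treated as the first roll of a two-roll frame, so the loop consumes rolls two at a time and runs past the 19 recorded rolls in the 10th frame.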
!pytest katta/test_bowling_game.py -sv
============================= test session starts ==============================
platform darwin -- Python 3.12.6, pytest-8.2.2, pluggy-1.5.0 -- /Users/a563420/python_training/testing/.venv/bin/python
cachedir: .pytest_cache
rootdir: /Users/a563420/python_training/testing
plugins: cov-6.0.0, anyio-4.4.0
collected 4 items

katta/test_bowling_game.py::test_gitter_gane PASSED
katta/test_bowling_game.py::test_all_ones PASSED
katta/test_bowling_game.py::test_one_spare PASSED
katta/test_bowling_game.py::test_one_strike FAILED

=================================== FAILURES ===================================
_______________________________ test_one_strike ________________________________

game = <bowling_game.BowlingGame object at 0x103facbc0>

    def test_one_strike(game):
        game.roll(10)
        game.roll(3)
        game.roll(2)
        for _ in range(16):
            game.roll(1)
>       assert game.score() == 36

katta/test_bowling_game.py:38:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
katta/bowling_game.py:15: in score
    if self.is_spare(roll_index):
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

self = <bowling_game.BowlingGame object at 0x103facbc0>, roll_index = 18

    def is_spare(self, roll_index):
>       return self.rolls[roll_index] + self.rolls[roll_index + 1] == 10
E       IndexError: list index out of range

katta/bowling_game.py:26: IndexError
=========================== short test summary info ============================
FAILED katta/test_bowling_game.py::test_one_strike - IndexError: list index out of range
========================= 1 failed, 3 passed in 0.38s ==========================
Step 9: Update Code to Handle Strikes¶
Modifying BowlingGame¶
Update bowling_game.py:
class BowlingGame:
    def __init__(self):
        self.rolls = []

    def roll(self, pins):
        self.rolls.append(pins)

    def score(self):
        total = 0
        roll_index = 0
        for frame in range(10):
            if self.is_strike(roll_index):
                total += 10 + self.rolls[roll_index + 1] + self.rolls[roll_index + 2]
                roll_index += 1
            elif self.is_spare(roll_index):
                total += 10 + self.rolls[roll_index + 2]
                roll_index += 2
            else:
                total += self.rolls[roll_index] + self.rolls[roll_index + 1]
                roll_index += 2
        return total

    def is_spare(self, roll_index):
        return self.rolls[roll_index] + self.rolls[roll_index + 1] == 10

    def is_strike(self, roll_index):
        return self.rolls[roll_index] == 10
Explanation:
- Added is_strike(roll_index) to check whether the roll is a strike.
- If it's a strike, we add 10 plus the strike bonus (the next two rolls) and advance roll_index by 1, because a strike frame consists of a single roll.
- The rest remains the same.
Running the Test Again¶
Run:
pytest test_bowling_game.py
Expected Result: The strike test should pass, along with the previous tests.
!pytest katta/test_bowling_game.py -sv
============================= test session starts ==============================
platform darwin -- Python 3.12.6, pytest-8.2.2, pluggy-1.5.0 -- /Users/a563420/python_training/testing/.venv/bin/python
cachedir: .pytest_cache
rootdir: /Users/a563420/python_training/testing
plugins: cov-6.0.0, anyio-4.4.0
collected 4 items

katta/test_bowling_game.py::test_gitter_gane PASSED
katta/test_bowling_game.py::test_all_ones PASSED
katta/test_bowling_game.py::test_one_spare PASSED
katta/test_bowling_game.py::test_one_strike PASSED

============================== 4 passed in 0.16s ===============================
Step 10: Refactor the Code¶
At this point, our code handles basic scoring, spares, and strikes. We can refactor to improve readability.
Refactoring BowlingGame¶
Update bowling_game.py:
class BowlingGame:
    def __init__(self):
        self.rolls = []

    def roll(self, pins):
        self.rolls.append(pins)

    def score(self):
        total = 0
        roll_index = 0
        for frame in range(10):
            if self.is_strike(roll_index):
                total += self.strike_score(roll_index)
                roll_index += 1
            elif self.is_spare(roll_index):
                total += self.spare_score(roll_index)
                roll_index += 2
            else:
                total += self.frame_score(roll_index)
                roll_index += 2
        return total

    def is_spare(self, roll_index):
        return self.frame_score(roll_index) == 10

    def is_strike(self, roll_index):
        return self.rolls[roll_index] == 10

    def frame_score(self, roll_index):
        return self.rolls[roll_index] + self.rolls[roll_index + 1]

    def spare_score(self, roll_index):
        return 10 + self.rolls[roll_index + 2]

    def strike_score(self, roll_index):
        return 10 + self.rolls[roll_index + 1] + self.rolls[roll_index + 2]
Explanation:
- Extracted the methods frame_score, spare_score, and strike_score for clarity.
- This makes the score() method cleaner and easier to understand.
!pytest katta/test_bowling_game.py -sv
============================= test session starts ==============================
platform darwin -- Python 3.12.6, pytest-8.2.2, pluggy-1.5.0 -- /Users/a563420/python_training/testing/.venv/bin/python
cachedir: .pytest_cache
rootdir: /Users/a563420/python_training/testing
plugins: cov-6.0.0, anyio-4.4.0
collected 4 items
katta/test_bowling_game.py::test_gitter_gane PASSED
katta/test_bowling_game.py::test_all_ones PASSED
katta/test_bowling_game.py::test_one_spare PASSED
katta/test_bowling_game.py::test_one_strike PASSED
============================== 4 passed in 0.10s ===============================
Step 11: Write a Test for a Perfect Game¶
Test: Perfect Game¶
Objective: Ensure that a perfect game (12 strikes) scores 300.
Writing the Test¶
Add the following test to test_bowling_game.py:
def test_perfect_game(game):
    for _ in range(12):
        game.roll(10)
    assert game.score() == 300
Explanation:
- In a perfect game, the player rolls 12 strikes (the extra two strikes are for the 10th frame bonus).
- The total score should be 300.
Running the Test¶
Run:
pytest test_bowling_game.py
Expected Result: The test should pass without any code changes. The loop scores exactly 10 frames, so the two bonus rolls of the 10th frame are only read as part of the last strike's bonus.
!pytest katta/test_bowling_game.py -sv
============================= test session starts ==============================
platform darwin -- Python 3.12.6, pytest-8.2.2, pluggy-1.5.0 -- /Users/a563420/python_training/testing/.venv/bin/python
cachedir: .pytest_cache
rootdir: /Users/a563420/python_training/testing
plugins: cov-6.0.0, anyio-4.4.0
collected 5 items
katta/test_bowling_game.py::test_gitter_gane PASSED
katta/test_bowling_game.py::test_all_ones PASSED
katta/test_bowling_game.py::test_one_spare PASSED
katta/test_bowling_game.py::test_one_strike PASSED
katta/test_bowling_game.py::test_perfect_game PASSED
============================== 5 passed in 0.10s ===============================
Step 12: Additional Tests and Edge Cases¶
Test: All Spares¶
Objective: Ensure that a game of all spares scores correctly.
Writing the Test¶
Add to test_bowling_game.py:
def test_all_spares(game):
    for _ in range(21):
        game.roll(5)
    assert game.score() == 150
Explanation:
- 21 rolls of 5 pins each (the extra roll is for the 10th frame spare).
- Each frame scores 10 + the next roll (5) = 15.
- Total score: 15 per frame over 10 frames = 15 * 10 = 150.
!pytest katta/test_bowling_game.py -sv
============================= test session starts ==============================
platform darwin -- Python 3.12.6, pytest-8.2.2, pluggy-1.5.0 -- /Users/a563420/python_training/testing/.venv/bin/python
cachedir: .pytest_cache
rootdir: /Users/a563420/python_training/testing
plugins: cov-6.0.0, anyio-4.4.0
collected 6 items
katta/test_bowling_game.py::test_gitter_gane PASSED
katta/test_bowling_game.py::test_all_ones PASSED
katta/test_bowling_game.py::test_one_spare PASSED
katta/test_bowling_game.py::test_one_strike PASSED
katta/test_bowling_game.py::test_perfect_game PASSED
katta/test_bowling_game.py::test_all_spares PASSED
============================== 6 passed in 0.12s ===============================
The complete test file (katta/test_bowling_game.py) at this point:
import pytest
from bowling_game import BowlingGame
@pytest.fixture
def game():
    return BowlingGame()
def test_gitter_gane(game):
    for _ in range(20):
        game.roll(0)
    assert game.score() == 0
def test_all_ones(game):
    for _ in range(20):
        game.roll(1)
    assert game.score() == 20
def test_one_spare(game):
    game.roll(5)
    game.roll(5)
    for _ in range(18):
        game.roll(1)
    assert game.score() == 29
def test_one_strike(game):
    game.roll(10)
    game.roll(3)
    game.roll(2)
    for _ in range(16):
        game.roll(1)
    assert game.score() == 36
def test_perfect_game(game):
    for _ in range(12):
        game.roll(10)
    assert game.score() == 300
def test_all_spares(game):
    for _ in range(21):
        game.roll(5)
    assert game.score() == 150
🏠 Exercises¶
Handle Invalid Inputs:
- Modify the roll method to handle invalid inputs (e.g., negative pins, pins greater than 10).
- Write tests to ensure that invalid inputs raise appropriate exceptions.
- Write a parametrized test that exercises different paths through the game.
Random Game Simulation:
- Write a function to simulate a random bowling game.
- Ensure that the total score calculated matches the expected score.
Handle Incomplete Game:
- Modify the code to handle cases where roll_index may go out of bounds:
def test_incomplete_game_with_spare_and_strike(game):
    # Frame 1: Strike (10 points)
    game.roll(10)
    # Frame 2: Spare (5 + 5 = 10 points)
    game.roll(5)
    game.roll(5)
    # Frame 3: Partial frame (only one roll)
    game.roll(3)
    # No further rolls provided, leaving the game incomplete
    # Expected score:
    # Frame 1: Strike = 10 + 5 + 5 = 20
    # Frame 2: Spare = 10 + 3 = 13
    # Frame 3: 3 points
    # Total = 20 + 13 + 3 = 36
    assert game.score() == 36
Implement a Command-Line Interface:
- Allow users to input rolls and display the score after each frame.
- Use the BowlingGame class as the backend.
Discussion¶
Pros:
- Tests serve as documentation, are readable, and independent of each other (orthogonal).
- Design focuses on interfaces rather than implementation, leading to more convenient interfaces.
- It’s easy to revert to the last working version because tests are consistently maintained, and progress is made in small steps.
- TDD is a methodology that is language-agnostic and tool-agnostic.
Cons:
- It takes more time when starting from scratch, although the difference becomes negligible in large projects. As a result, TDD is not particularly useful for prototypes or short projects (in such cases, testing in general may not make much sense).
- You cannot follow TDD rules 100% of the time; they need to be filtered through personal experience.
- TDD doesn’t answer all questions: What about E2E tests? What about multi-layered architecture? (Hint: BDD).
The FIRST Principles of Testing¶
In this chapter, we delve into the FIRST principles of testing—Fast, Independent, Repeatable, Self-validating, and Timely/Thorough. Understanding and applying these principles will help you write effective and efficient tests for your Python applications. We'll explore each principle in detail, providing practical examples and best practices to enhance your testing methodologies.
Introduction to FIRST Principles¶
The FIRST principles serve as guidelines to create high-quality tests that are maintainable and reliable. They ensure that your test suite is:
- Fast: Encourages frequent execution of tests without hesitation.
- Independent: Ensures tests do not rely on each other's state or the environment.
- Repeatable: Guarantees consistent results every time the tests are run.
- Self-validating: Eliminates the need for manual result inspection.
- Timely/Thorough: Promotes writing tests at the right time and covering all necessary scenarios.
Fast¶
Importance of Speed in Testing¶
Tests should execute swiftly to encourage developers to run them frequently. If tests are slow, developers may hesitate to run them often, leading to less frequent integration and delayed detection of issues.
Best Practices for Fast Tests¶
- Optimize Setup and Teardown: Minimize the time spent in setting up and tearing down tests.
- Mock External Dependencies: Use mocking to simulate external systems or services.
- Run Tests in Memory: Avoid disk I/O operations if possible.
Python Example¶
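import unittest
from unittest.mock import MagicMock
def fetch_value(service):
    # Hypothetical code under test: reads one key from the service's response
    return service.get_data()['key']
class FastTest(unittest.TestCase):
    def test_fast_execution(self):
        # Arrange: an in-memory mock stands in for the slow external service
        external_service = MagicMock()
        external_service.get_data.return_value = {'key': 'value'}
        # Act: exercise the code under test, not the mock itself
        result = fetch_value(external_service)
        # Assert
        self.assertEqual(result, 'value')
        external_service.get_data.assert_called_once_with()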
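The "Run Tests in Memory" practice can be sketched with io.StringIO, which behaves like an open text file but never touches the disk. This is a minimal sketch under one assumption: count_lines is a hypothetical function written to accept any file-like object rather than a path.

```python
import io
import unittest


def count_lines(stream):
    # Hypothetical code under test: works with any iterable file-like object
    return sum(1 for _ in stream)


class InMemoryTest(unittest.TestCase):
    def test_count_lines_without_touching_disk(self):
        # io.StringIO acts like an opened text file but lives entirely in memory
        fake_file = io.StringIO('first\nsecond\nthird\n')
        self.assertEqual(count_lines(fake_file), 3)
```

Accepting a stream instead of a filename is what makes the in-memory test possible; production code can still pass the result of open(path).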
Independent¶
The 3 As: Arrange, Act, Assert¶
- Arrange: Set up the data and environment for the test.
- Act: Execute the functionality under test.
- Assert: Verify the outcome.
Ensuring Independence¶
- Isolate Test Data: Each test should create its own data.
- No Order Dependency: Tests should pass or fail regardless of the execution order.
- Avoid Shared State: Do not rely on data modified by other tests.
Best Practices¶
- Use fixtures or setup methods to initialize data.
- Clean up any changes made during the test in the teardown phase.
- Use mocking to isolate external dependencies.
Python Example¶
import pytest
def process_data(data):
    # Sample function to double the input value
    return data['input'] * 2
def test_independent_behavior():
    # Arrange
    data = {'input': 10}
    # Act
    result = process_data(data)
    # Assert
    assert result == 20
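The "fixtures or setup methods" and "clean up in teardown" practices can be sketched with unittest's setUp and tearDown. In this minimal sketch, each test gets its own temporary directory, so neither test can see the other's files and the execution order does not matter. The class and file names are illustrative only; pytest offers the built-in tmp_path fixture for the same purpose.

```python
import os
import tempfile
import unittest


class IndependentFileTest(unittest.TestCase):
    def setUp(self):
        # Arrange: every test gets its own fresh temporary directory
        self.tmpdir = tempfile.TemporaryDirectory()
        self.path = os.path.join(self.tmpdir.name, 'data.txt')

    def tearDown(self):
        # Teardown: remove the directory and its contents, even if the test failed
        self.tmpdir.cleanup()

    def test_write_creates_file(self):
        with open(self.path, 'w') as f:
            f.write('hello')
        self.assertTrue(os.path.isfile(self.path))

    def test_file_does_not_exist_yet(self):
        # Passes in any order: the other test wrote into a different directory
        self.assertFalse(os.path.exists(self.path))
```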
Repeatable¶
Deterministic Results¶
Tests should produce the same results every time they are run, regardless of the environment or timing.
Avoiding Environmental Dependencies¶
- Control Randomness: Seed random number generators.
- Mock Time-dependent Functions: Replace actual time functions with fixed values.
Best Practices¶
- Use Mocking: Replace non-deterministic functions with mocks.
- Data Helpers: Use helper functions or classes to set up consistent data.
Python Examples¶
Controlling Randomness:
import random
def test_random_number():
    random.seed(0)
    assert random.random() == 0.8444218515250481
The random module in Python uses a pseudo-random number generator (PRNG) algorithm. When you set the seed using random.seed(0), it initializes the PRNG to a specific state. This ensures that the sequence of random numbers generated is reproducible.
The number 0.8444218515250481 is the first number in the sequence generated by the PRNG when the seed is set to 0. The PRNG algorithm used by Python's random module is deterministic, meaning that given the same initial seed, it will always produce the same sequence of numbers.
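Seeding the global generator works, but it also changes the random state seen by every other test in the same process. One alternative, sketched here, is to pass a dedicated random.Random instance into the code under test. roll_dice is a hypothetical function used only to illustrate the idea.

```python
import random


def roll_dice(rng: random.Random, count: int) -> list[int]:
    # Hypothetical code under test: receives its random generator as a parameter
    return [rng.randint(1, 6) for _ in range(count)]


def test_roll_dice_is_repeatable():
    # Two generators created with the same seed yield the same sequence,
    # and the global random state used by other tests is never touched
    first = roll_dice(random.Random(42), 5)
    second = roll_dice(random.Random(42), 5)
    assert first == second
    assert all(1 <= value <= 6 for value in first)
```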
Mocking Date and Time:
from unittest.mock import patch
import datetime
def get_current_time():
    return datetime.datetime.now()
# Build the expected value before patching: inside the test, datetime.datetime is the mock
FIXED_NOW = datetime.datetime(2020, 1, 1)
@patch('datetime.datetime')
def test_time_dependent_function(mock_datetime):
    mock_datetime.now.return_value = FIXED_NOW
    assert get_current_time() == FIXED_NOW
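The "Data Helpers" practice can be sketched as a small factory function that builds test objects from fixed, explicit defaults and lets each test override only the fields it cares about. User and make_user are hypothetical names used only for this illustration.

```python
import datetime
from dataclasses import dataclass


@dataclass
class User:
    # Hypothetical domain object used only for this example
    name: str
    email: str
    created_at: datetime.datetime


def make_user(**overrides):
    # Data helper: fixed, explicit defaults instead of random values or "now"
    defaults = {
        'name': 'Alice',
        'email': 'alice@example.com',
        'created_at': datetime.datetime(2020, 1, 1),
    }
    defaults.update(overrides)
    return User(**defaults)


def test_renamed_user_keeps_other_fields():
    user = make_user(name='Bob')
    assert user.name == 'Bob'
    assert user.email == 'alice@example.com'
```

Because every default is a constant, two runs build identical objects, and the values a test relies on are visible in one place.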
Self-Validating¶
Automated Validation¶
Tests should automatically determine whether they pass or fail without requiring manual inspection.
Best Practices¶
- Use Assertions: Rely on assertions to validate outcomes.
- Avoid Print Statements: Do not use print statements for validation.
- Consistent Assertion Messages: Provide clear messages for assertion failures.
Python Example¶
def perform_calculation(a, b):
    # Hypothetical SUT for this example: adds two numbers
    return a + b
def test_self_validating():
    result = perform_calculation(5, 5)
    assert result == 10, "The calculation result should be 10."
Timely/Thorough¶
Writing Tests at the Right Time¶
- Test-Driven Development (TDD): Write tests before the actual code.
- Behavior-Driven Development (BDD): Focus on the behavior and write tests accordingly.
Ensuring Thoroughness¶
- Cover Edge Cases: Test boundary conditions and corner cases.
- Test with Large Data Sets: Assess performance and scalability.
- Security Testing: Validate behavior with different user roles and permissions.
- Error Handling: Test exceptions, errors, and invalid inputs.
Best Practices¶
- Aim Beyond 100% Coverage: Focus on meaningful test cases rather than just coverage metrics.
- Use Parameterized Tests: Test multiple inputs with the same test logic.
- Include Negative Tests: Ensure the system behaves correctly with invalid inputs.
Python Examples¶
Testing Edge Cases:
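def divide(a, b):
    # Hypothetical SUT: returns None instead of raising on division by zero
    return None if b == 0 else a / b
def test_edge_case_zero():
    result = divide(10, 0)
    assert result is None, "Division by zero should return None."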
Parameterized Tests with pytest:
import pytest
# factorial is the SUT defined at the start of the pytest Fundamentals section
@pytest.mark.parametrize("number, expected", [
    (0, 1),
    (1, 1),
    (5, 120),
])
def test_factorial(number, expected):
    assert factorial(number) == expected
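The "Include Negative Tests" practice can be sketched for the same function. In unittest, subTest is the closest counterpart of pytest's parametrize: one test body runs over several inputs, and each failing input is reported separately. The factorial below is a stand-in. Only its TypeError check comes from the SUT defined earlier; the loop is an assumption added so the example runs on its own.

```python
import unittest


def factorial(num: int) -> int:
    # Stand-in for the factorial SUT: the TypeError check mirrors the original,
    # the loop is an assumed implementation
    if not isinstance(num, int):
        raise TypeError('Argument must be int')
    result = 1
    for i in range(2, num + 1):
        result *= i
    return result


class FactorialNegativeTest(unittest.TestCase):
    def test_rejects_non_int_arguments(self):
        # Each bad input runs as its own subtest inside one test method
        for bad_input in [2.5, '3', None, [1]]:
            with self.subTest(bad_input=bad_input):
                with self.assertRaises(TypeError):
                    factorial(bad_input)
```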
Summary¶
Applying the FIRST principles ensures that your tests are efficient, reliable, and maintainable. By writing tests that are fast, independent, repeatable, self-validating, and timely/thorough, you improve the quality of your software and the confidence in your codebase.
Further Reading¶
- Test-Driven Development by Example by Kent Beck
- Python Testing with pytest by Brian Okken
- Clean Code by Robert C. Martin (Uncle Bob)
Mocking in Python Testing¶
In this chapter, we explore the concept of mocking in Python testing. Mocking allows you to isolate the system under test (SUT) by replacing parts of the system that are external or not easily testable with mock objects. This is especially useful when the SUT interacts with external resources like databases, file systems, or network services.
Introduction to Mocking¶
What is Mocking?¶
Mocking is a technique used in unit testing where you replace real objects with mock objects that simulate the behavior of the real ones. This helps in isolating the code under test and controlling the environment in which the tests run.
Why Use Mocking?¶
- Isolation: Test components in isolation without dependencies.
- Control: Simulate specific scenarios and edge cases.
- Performance: Avoid slow operations like network calls or file I/O.
- Reliability: Eliminate flakiness due to external factors.
System Under Test (SUT)¶
Let's consider a simple module as our system under test.
The my_remove_module.py Module¶
%%writefile my_remove_module.py
# my_remove_module.py
import os
DEFAULT_EXTENSION = '.txt'
def my_remove(filename):
    if '.' not in filename:
        filename += DEFAULT_EXTENSION
    os.remove(filename)
Overwriting my_remove_module.py
Explanation:
- Imports: The module imports the os module to interact with the operating system.
- Constants: DEFAULT_EXTENSION is set to '.txt'.
- Function: my_remove takes a filename and removes it from the file system.
  - If the filename does not contain an extension (no . character), it appends the default extension.
Testing Without Mocking¶
Initially, we'll write tests without using mocking to understand the limitations.
Test Cases Without Mocking¶
%%writefile test_my_remove_module.py
# test_my_remove_module.py
import os
from my_remove_module import my_remove
def test_provided_extension_should_be_used():
    filename = 'file.md'
    # Create the file
    open(filename, 'w').close()
    assert os.path.isfile(filename)
    # Call the function
    my_remove(filename)
    assert not os.path.isfile(filename)
def test_when_extension_is_missing_then_use_default_one():
    filename = 'file.txt'
    # Create the file
    open(filename, 'w').close()
    assert os.path.isfile(filename)
    # Call the function with filename without extension
    my_remove('file')
    assert not os.path.isfile(filename)
Overwriting test_my_remove_module.py
!pytest test_my_remove_module.py -v
============================= test session starts ==============================
platform darwin -- Python 3.12.6, pytest-8.2.2, pluggy-1.5.0 -- /Users/a563420/python_training/testing/.venv/bin/python
cachedir: .pytest_cache
rootdir: /Users/a563420/python_training/testing
plugins: cov-6.0.0, anyio-4.4.0
collected 2 items
test_my_remove_module.py::test_provided_extension_should_be_used PASSED [ 50%]
test_my_remove_module.py::test_when_extension_is_missing_then_use_default_one PASSED [100%]
============================== 2 passed in 0.12s ===============================
Limitations Without Mocking:
- Dependency on File System: The tests interact with the actual file system.
- Side Effects: Creates and deletes real files, which can be risky.
- Performance: File I/O operations can slow down the tests.
- Environment Sensitivity: Tests may fail if the environment doesn't permit file operations.
Testing With Mocking¶
To overcome the limitations, we'll use mocking to simulate the file system operations.
Using unittest.mock¶
Python's unittest.mock library provides tools for mocking objects in tests.
Refactoring Tests with Mocking¶
%%writefile test_my_remove_module_mock.py
# test_my_remove_module_mock.py
from unittest import mock
from my_remove_module import my_remove
@mock.patch('my_remove_module.os')
def test_provided_extension_should_be_used(mock_os):
    filename = 'file.md'
    my_remove(filename)
    mock_os.remove.assert_called_once_with(filename)
@mock.patch('my_remove_module.os')
def test_when_extension_is_missing_then_use_default_one(mock_os):
    my_remove('file')
    mock_os.remove.assert_called_once_with('file.txt')
Overwriting test_my_remove_module_mock.py
!pytest test_my_remove_module_mock.py -v
============================= test session starts ==============================
platform darwin -- Python 3.12.6, pytest-8.2.2, pluggy-1.5.0 -- /Users/a563420/python_training/testing/.venv/bin/python
cachedir: .pytest_cache
rootdir: /Users/a563420/python_training/testing
plugins: cov-6.0.0, anyio-4.4.0
collected 2 items
test_my_remove_module_mock.py::test_provided_extension_should_be_used PASSED [ 50%]
test_my_remove_module_mock.py::test_when_extension_is_missing_then_use_default_one PASSED [100%]
============================== 2 passed in 0.19s ===============================
Explanation:
- Decorators: Use @mock.patch('my_remove_module.os') to replace the os module in my_remove_module with a mock.
- Mock Objects: mock_os is the mocked version of the os module.
- Assertions: assert_called_once_with verifies that os.remove was called exactly once with the specified argument.
Understanding mock.patch¶
How mock.patch Works¶
- Target: The string 'my_remove_module.os' specifies the exact location to patch.
  - It means, "In the module my_remove_module, replace os with a mock."
- Replacement: The original os module is temporarily replaced with a mock during the test.
Common Pitfall¶
It's crucial to patch a name in the module where it is looked up, not where it is defined. Rebinding os somewhere else (for example, in the test module) won't affect my_remove_module.os, because my_remove_module holds its own reference to the os module.
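The pitfall bites hardest with from-imports. A module that does `from os import remove` copies the function reference into its own namespace, so patching 'os.remove' no longer reaches it. The sketch below builds a throwaway in-memory module (the name cleaner is hypothetical, and exec is used only so the example is self-contained) to show both targets side by side.

```python
import os
import sys
import types
from unittest import mock

# Build a throwaway in-memory module named 'cleaner' (hypothetical) whose
# source uses `from os import remove`, i.e. it copies the function reference
cleaner = types.ModuleType('cleaner')
exec(
    'from os import remove\n'
    'def clean(path):\n'
    '    remove(path)\n',
    cleaner.__dict__,
)
sys.modules['cleaner'] = cleaner

# Patching where remove is defined: cleaner still holds the real function
with mock.patch('os.remove'):
    assert cleaner.remove is not os.remove

# Patching where remove is looked up: cleaner.clean now calls the mock
with mock.patch('cleaner.remove') as remove_mock:
    cleaner.clean('file.txt')
    remove_mock.assert_called_once_with('file.txt')

# Both patches are undone when their with-blocks exit
assert cleaner.remove is os.remove
```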
Manual Patching¶
Sometimes, you might need to patch objects manually without decorators.
Manual Patching Example¶
%%writefile test_my_remove_module_manual.py
# test_my_remove_module_manual.py
from unittest import mock
import my_remove_module
from my_remove_module import my_remove
def test_when_extension_is_missing_then_use_default_one():
    # Save the real os module
    real_os = my_remove_module.os
    # Replace os with a mock
    my_remove_module.os = mock.MagicMock()
    try:
        # Filename without an extension, so the default one should be appended
        my_remove('file')
        # Assert os.remove was called with the default extension added
        my_remove_module.os.remove.assert_called_once_with('file.txt')
    finally:
        # Restore the real os module
        my_remove_module.os = real_os
Overwriting test_my_remove_module_manual.py
!pytest test_my_remove_module_manual.py -v
============================= test session starts ==============================
platform darwin -- Python 3.12.6, pytest-8.2.2, pluggy-1.5.0 -- /Users/a563420/python_training/testing/.venv/bin/python
cachedir: .pytest_cache
rootdir: /Users/a563420/python_training/testing
plugins: cov-6.0.0, anyio-4.4.0
collected 1 item
test_my_remove_module_manual.py::test_when_extension_is_missing_then_use_default_one PASSED [100%]
============================== 1 passed in 0.18s ===============================
Explanation:
- Manual Replacement:
  - Save the real os module.
  - Replace my_remove_module.os with a mock object.
- Try-Finally Block:
  - Ensure the real os module is restored after the test, even if an exception occurs.
- Assertion:
  - Check that os.remove was called with the expected filename.
When to Use Manual Patching¶
- When you need more control over the patching process.
- When decorators or context managers are not suitable.
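For comparison, mock.patch.object can perform the same temporary replacement either as a context manager or through a patcher's start() and stop() methods. Both restore the original attribute, which the manual version does in its finally block. current_dir_name is a hypothetical function used only for this sketch.

```python
import os
from unittest import mock


def current_dir_name():
    # Hypothetical code under test: looks up os.getcwd at call time
    return os.path.basename(os.getcwd())


# 1) Context manager: the patch is active only inside the with-block
with mock.patch.object(os, 'getcwd', return_value='/home/user/project'):
    assert current_dir_name() == 'project'

# 2) start()/stop(): handy when the patch must span setUp() and tearDown()
patcher = mock.patch.object(os, 'getcwd', return_value='/tmp/build')
getcwd_mock = patcher.start()
try:
    assert current_dir_name() == 'build'
    getcwd_mock.assert_called_once_with()
finally:
    patcher.stop()  # always restore the real os.getcwd

# The real function is back in place
assert not isinstance(os.getcwd, mock.MagicMock)
```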
Mocking Functions¶
Essentials¶
The unittest.mock library provides the MagicMock class, which is a powerful and flexible mock object that can mimic any Python object.
Creating a MagicMock instance:
from unittest import mock
m = mock.MagicMock()
print(m)
print(m())
print(m().revert(0).delete())
<MagicMock id='4398390144'>
<MagicMock name='mock()' id='4397114688'>
<MagicMock name='mock().revert().delete()' id='4393699280'>
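- m: a MagicMock instance.
- m(): calling the mock as if it were a function returns another MagicMock instance.
- m().revert(0).delete(): calls and attribute accesses can be chained arbitrarily; each step returns a child MagicMock whose name records the path that produced it.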
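The "Magic" in MagicMock refers to Python's special (dunder) methods. MagicMock pre-configures them with default return values, while a plain Mock does not support them. A short sketch using return_value:

```python
from unittest import mock

m = mock.MagicMock()
# MagicMock pre-configures magic methods with default return values
assert len(m) == 0
assert list(m) == []
# ...and each one can be reconfigured through return_value like any other mock
m.__len__.return_value = 3
assert len(m) == 3
m.__getitem__.return_value = 'value'
assert m['any key'] == 'value'

# A plain Mock does not support magic methods such as __len__
try:
    len(mock.Mock())
except TypeError:
    print('Mock has no __len__')
```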
Setting Return Values:
You can specify the return value of a mock function using the return_value attribute.
m = mock.MagicMock()
m.return_value = 42
result = m(84, foo=3)
assert result == 42
print(result)
42
Asserting Calls:
You can assert that a mock was called with specific arguments.
m.assert_called_once_with(84, foo=3)
If the mock was not called with the specified arguments exactly once, an AssertionError is raised.
Checking Call Arguments:
You can inspect how a mock was called using call_args.
print(m.call_args)
call(84, foo=3)
assert m.call_args == mock.call(84, foo=3)
Example: Mocking a Function in a Module¶
Let's consider a module power_reset.py that performs a POST request to reset power.
power_reset.py:
%%writefile power_reset.py
# power_reset.py
import requests
def power_reset():
    ret = requests.post('https://127.0.0.1:8000/power_reset', json={})
    print(ret)
    if ret != 42:
        raise Exception("Power reset failed")
Overwriting power_reset.py
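Explanation:
- The power_reset function sends a POST request to a local URL.
- It prints the response and raises an exception if the response is not 42. (A real requests.post call returns a Response object, so the comparison with 42 only makes sense together with the mock used in the test.)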
Writing a Test with Mocking:
We can mock the requests module to simulate the POST request without actually performing it.
test_power_reset.py:
%%writefile test_power_reset.py
# test_power_reset.py
from unittest import mock
from power_reset import power_reset
@mock.patch('power_reset.requests')
def test_power_reset(mock_requests):
    # Arrange
    mock_requests.post.return_value = 42  # Mock the return value of requests.post
    # Act
    power_reset()
    # Assert
    mock_requests.post.assert_called_once_with(
        'https://127.0.0.1:8000/power_reset', json={}
    )
Overwriting test_power_reset.py
Explanation:
- We use @mock.patch('power_reset.requests') to replace the requests module in power_reset with a mock.
- We set the return_value of mock_requests.post to 42.
- We call power_reset() and assert that requests.post was called once with the specified URL and JSON data.
Running the Test:
!pytest test_power_reset.py -vs
============================= test session starts ==============================
platform darwin -- Python 3.12.6, pytest-8.2.2, pluggy-1.5.0 -- /Users/a563420/python_training/testing/.venv/bin/python
cachedir: .pytest_cache
rootdir: /Users/a563420/python_training/testing
plugins: cov-6.0.0, anyio-4.4.0
collected 1 item
test_power_reset.py::test_power_reset 42
PASSED
============================== 1 passed in 0.61s ===============================
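In real code, requests.post returns a Response object, and callers usually check its status_code attribute. The sketch below shows how to configure an attribute on a mocked return value (http.post.return_value.status_code). It rests on two assumptions: power_reset_checked is a hypothetical variant of power_reset, and the HTTP client is passed in as the http parameter instead of being patched.

```python
from unittest import mock


def power_reset_checked(http):
    # Hypothetical variant of power_reset: the HTTP client is passed in, and the
    # response's status_code is checked instead of comparing the response to 42
    response = http.post('https://127.0.0.1:8000/power_reset', json={})
    if response.status_code != 200:
        raise RuntimeError('Power reset failed')
    return response


def test_power_reset_checked_accepts_200():
    http = mock.MagicMock()
    # Configure an attribute on the object that http.post() will return
    http.post.return_value.status_code = 200
    power_reset_checked(http)
    http.post.assert_called_once_with('https://127.0.0.1:8000/power_reset', json={})


def test_power_reset_checked_rejects_500():
    http = mock.MagicMock()
    http.post.return_value.status_code = 500
    try:
        power_reset_checked(http)
    except RuntimeError:
        return  # expected
    raise AssertionError('RuntimeError was not raised')
```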
Raising Exceptions with Mocks¶
You can configure a mock to raise an exception when called, using the side_effect attribute.
Example:
m = mock.MagicMock()
m.side_effect = ZeroDivisionError
try:
    m()
except ZeroDivisionError:
    print("Caught ZeroDivisionError as expected.")
Caught ZeroDivisionError as expected.
Explanation:
- Setting m.side_effect = ZeroDivisionError causes the mock to raise ZeroDivisionError when called.
- This is useful for testing how your code handles exceptions from dependencies.
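A sketch of that use case: the mocked dependency raises, and the test checks that the code under test handles the error. fetch_status and its "offline" fallback are hypothetical, chosen only to give the exception something to exercise.

```python
from unittest import mock


def fetch_status(client):
    # Hypothetical code under test: falls back to 'offline' when the client fails
    try:
        return client.get_status()
    except ConnectionError:
        return 'offline'


def test_fetch_status_handles_connection_error():
    client = mock.MagicMock()
    # Every call to client.get_status() now raises ConnectionError
    client.get_status.side_effect = ConnectionError('network down')
    assert fetch_status(client) == 'offline'
    client.get_status.assert_called_once_with()
```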
Returning Multiple Values¶
You can use side_effect to specify a list of return values or exceptions for successive calls.
Example:
m = mock.MagicMock()
m.side_effect = [1, 2, KeyError, 3]
print(m())  # Outputs: 1
print(m())  # Outputs: 2
try:
    m()  # Raises KeyError
except KeyError:
    print("Caught KeyError as expected.")
print(m())  # Outputs: 3
1
2
Caught KeyError as expected.
3
Explanation:
- Each call to m() returns the next value in the side_effect list.
- If the next item is an exception (class or instance), it is raised instead of returned.
- Once the list is exhausted, any further call raises StopIteration.
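A list-valued side_effect is a natural fit for testing retry logic: the mock fails a fixed number of times and then succeeds. call_with_retry is a hypothetical helper used only for this sketch.

```python
from unittest import mock


def call_with_retry(func, attempts=3):
    # Hypothetical helper under test: retries func on ConnectionError
    for attempt in range(1, attempts + 1):
        try:
            return func()
        except ConnectionError:
            if attempt == attempts:
                raise


# Fails twice, then succeeds on the third call
flaky = mock.MagicMock(side_effect=[ConnectionError, ConnectionError, 'ok'])
assert call_with_retry(flaky) == 'ok'
assert flaky.call_count == 3
```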
Mocking with Functions and Lambdas¶
You can set side_effect to a function or lambda to dynamically determine the return value based on the input arguments.
Using a Lambda Function:
m = mock.MagicMock()
m.side_effect = lambda x: x + 2
print(m(5)) # Outputs: 7
7
Using a Defined Function:
def add_two(x):
    return x + 2
m = mock.MagicMock()
m.side_effect = add_two
print(m(5))  # Outputs: 7
7
def add_two(x):
    return x + 2
m = mock.MagicMock()
m.side_effect = add_two
print(m(5,6))  # Raises TypeError: the mock passes both arguments to add_two(x)
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[82], line 7
      4 m = mock.MagicMock()
      5 m.side_effect = add_two
----> 7 print(m(5,6)) # Outputs: 7
File ~/.pyenv/versions/3.12.6/lib/python3.12/unittest/mock.py:1137, in CallableMixin.__call__(self, *args, **kwargs)
   1135 self._mock_check_sig(*args, **kwargs)
   1136 self._increment_mock_call(*args, **kwargs)
-> 1137 return self._mock_call(*args, **kwargs)
File ~/.pyenv/versions/3.12.6/lib/python3.12/unittest/mock.py:1141, in CallableMixin._mock_call(self, *args, **kwargs)
   1140 def _mock_call(self, /, *args, **kwargs):
-> 1141 return self._execute_mock_call(*args, **kwargs)
File ~/.pyenv/versions/3.12.6/lib/python3.12/unittest/mock.py:1202, in CallableMixin._execute_mock_call(self, *args, **kwargs)
   1200 raise result
   1201 else:
-> 1202 result = effect(*args, **kwargs)
   1204 if result is not DEFAULT:
   1205 return result
TypeError: add_two() takes 1 positional argument but 2 were given
def add_two(x, y):
    return x + y + 2
m = mock.MagicMock()
m.side_effect = add_two
print(m(5, 6))  # Outputs: 13
13
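Explanation:
- When invoked, the mock calls the function specified in side_effect with all of its own arguments and returns that function's result.
- This allows for more complex logic in determining the return value, but the function's signature must accept the arguments the mock is called with, as the TypeError above shows.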
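Two further side_effect patterns, sketched under the stated assumptions: a bound method such as dict.get works as a ready-made fake lookup (fake_env and getenv are illustrative names), and returning mock.DEFAULT from a side_effect function makes the mock fall back to its return_value.

```python
from unittest import mock

# A bound method works as side_effect: the mock forwards its arguments to dict.get
fake_env = {'HOME': '/home/test', 'LANG': 'C'}
getenv = mock.MagicMock(side_effect=fake_env.get)
assert getenv('HOME') == '/home/test'
assert getenv('MISSING', 'fallback') == 'fallback'

# Returning mock.DEFAULT from side_effect falls back to return_value
m = mock.MagicMock(return_value='default value')
m.side_effect = lambda x: 'special' if x == 0 else mock.DEFAULT
assert m(0) == 'special'
assert m(1) == 'default value'
```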
Assertions on Mocks¶
Mocks provide several assertion methods to check how they were called.
Example:
m = mock.MagicMock()
m(10)
m(20)
# Assert that the last call was with 20
m.assert_called_with(20)
# This will raise an AssertionError because the mock was called twice
try:
    m.assert_called_once_with(20)
except AssertionError:
    print("AssertionError as expected: The mock was called more than once.")
# Assert that the mock was called at least once with 10
m.assert_any_call(10)
# Assert that the mock was called at least once with 20
m.assert_any_call(20)
AssertionError as expected: The mock was called more than once.
Common Assertion Methods:
- assert_called_with(*args, **kwargs): Asserts the last call's arguments.
- assert_called_once_with(*args, **kwargs): Asserts the mock was called exactly once with the specified arguments.
- assert_any_call(*args, **kwargs): Asserts the mock was called with the specified arguments at least once.
- assert_not_called(): Asserts the mock was never called.
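A short sketch of assert_not_called together with a few related tools from unittest.mock: assert_has_calls for checking a sequence of calls, mock.ANY for ignoring an argument, and reset_mock for clearing the recorded history.

```python
from unittest import mock

m = mock.MagicMock()
m.assert_not_called()  # passes: no calls recorded yet

m('a', retries=1)
m('b', retries=2)

# The listed calls must appear consecutively, in this order, in m.mock_calls
m.assert_has_calls([mock.call('a', retries=1), mock.call('b', retries=2)])
# mock.ANY compares equal to anything, so only the first argument is checked here
m.assert_called_with('b', retries=mock.ANY)

m.reset_mock()  # forget every recorded call
m.assert_not_called()
```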
Inspecting Mock Attributes¶
Mocks keep track of how they were called, providing useful attributes for inspection.
Example:
m = mock.MagicMock()
m(5)
m(7)
print(m.called) # Outputs: True
print(m.call_count) # Outputs: 2
print(m.call_args) # Outputs: call(7)
print(m.call_args_list) # Outputs: [call(5), call(7)]
# Assert the call history
assert m.call_args_list == [mock.call(5), mock.call(7)]
True
2
call(7)
[call(5), call(7)]
Attributes Explained:
- called: True if the mock has been called at least once.
- call_count: Number of times the mock has been called.
- call_args: Arguments of the most recent call.
- call_args_list: List of all call arguments in order.
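The attributes above describe calls to the mock itself. For calls on its children, two more attributes exist: method_calls records calls to the mock's attributes, while mock_calls also includes calls made on returned objects. A sketch with an illustrative db mock:

```python
from unittest import mock

db = mock.MagicMock()
db.connect('localhost')
db.cursor().execute('SELECT 1')

# method_calls: calls made to the mock's attributes (its "methods")
assert db.method_calls == [mock.call.connect('localhost'), mock.call.cursor()]
# mock_calls: every call, including calls on objects the mock returned
assert db.mock_calls == [
    mock.call.connect('localhost'),
    mock.call.cursor(),
    mock.call.cursor().execute('SELECT 1'),
]
```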
Practical Example with Assertions¶
Let's revisit the power_reset example and add more assertions.
Updated test_power_reset.py:
from unittest import mock
from power_reset import power_reset
@mock.patch('power_reset.requests')
def test_power_reset(mock_requests):
    # Arrange
    mock_requests.post.return_value = 42
    # Act
    power_reset()
    # Assert
    assert mock_requests.post.called
    assert mock_requests.post.call_count == 1
    mock_requests.post.assert_called_once_with(
        'https://127.0.0.1:8000/power_reset', json={}
    )
    # Extra: compare call_args directly
    assert mock_requests.post.call_args == mock.call(
        'https://127.0.0.1:8000/power_reset', json={}
    )
Explanation:
- We assert that requests.post was called.
- We check the call count to ensure it was called exactly once.
- We use assert_called_once_with to verify the arguments.
- We compare call_args with the expected mock.call to check the details of the call.
Common Pitfalls¶
- Patching the Wrong Target: Always patch the object in the module where it is looked up, not where it is defined.
- Not Restoring Original Objects: If you manually replace objects, ensure you restore them after the test.
- Mocks Return Mocks: By default, calling a mock returns another mock. Set return_value or side_effect to control this behavior.
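A related pitfall is that a plain MagicMock accepts any arguments and invents any attribute, so a test can pass even when the real call would fail. mock.create_autospec copies the signature of the object it mocks. In this sketch, send_email is a hypothetical dependency.

```python
from unittest import mock


def send_email(address, subject):
    # Hypothetical dependency: the real one would talk to a mail server
    raise RuntimeError('should never run in a unit test')


# A plain MagicMock accepts any arguments and invents any attribute
loose = mock.MagicMock()
loose('only-one-arg')  # no error, although send_email needs two arguments
loose.sned()           # a typo in a method name also goes unnoticed

# create_autospec copies send_email's signature, so wrong calls fail immediately
strict = mock.create_autospec(send_email, return_value=None)
strict('bob@example.com', 'Hello')
strict.assert_called_once_with('bob@example.com', 'Hello')
try:
    strict('only-one-arg')
except TypeError as exc:
    print('Rejected:', exc)
```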
🏠 Exercise: Mocking the open Function¶
Objective¶
Write a test that mocks the built-in open function to verify the behavior of the print_file function. Your test should:
- Ensure that print_file('file.txt') calls open('file.txt').
- Verify that the print function is called (without checking its arguments).
- Confirm that the print function is called with whatever open().__enter__().read() returns.
Instructions¶
- Use the unittest.mock library to mock the open and print functions.
- Implement the test in such a way that it passes all the assertions.
- Remember to handle the context manager methods __enter__ and __exit__ when mocking open.
Starting Code¶
Here is the code for the print_file function:
# mocking_open.py
def print_file(filename):
    with open(filename) as stream:
        content = stream.read()
    # Equivalent to:
    # context_manager = open(filename)
    # stream = context_manager.__enter__()
    # content = stream.read()
    # context_manager.__exit__(None, None, None)
    print(content)
Your task is to write a test for the print_file function that satisfies the objectives listed above.
Solution¶
%%writefile mocking_open.py
# mocking_open.py
def print_file(filename):
    with open(filename) as stream:
        content = stream.read()
    # Equivalent to:
    # context_manager = open(filename)
    # stream = context_manager.__enter__()
    # content = stream.read()
    # context_manager.__exit__(None, None, None)
    print(content)
Overwriting mocking_open.py
%%writefile test_mocking_open.py
# test_mocking_open.py
from unittest import mock
from mocking_open import print_file
@mock.patch('mocking_open.open')
@mock.patch('mocking_open.print')
def test_print_file(print_mock, open_mock):
    content = '42'
    # Arrange
    # Mock the context manager returned by open()
    context_manager = open_mock.return_value
    # Mock the stream returned by __enter__()
    stream = context_manager.__enter__.return_value
    # Mock the return value of stream.read()
    stream.read.return_value = content
    # Bad Arrange (same effect, but one long chain is hard to read):
    # open_mock.return_value.__enter__.return_value.read.return_value = content
    # Act
    print_file('file.txt')
    # Assert
    # Ensure open was called once with 'file.txt'
    open_mock.assert_called_once_with('file.txt')
    # Ensure print was called (without checking arguments)
    assert print_mock.called
    # Ensure print was called with the content read from the file
    print_mock.assert_called_once_with(content)
    # Optionally, check that __enter__ and __exit__ were called
    context_manager.__enter__.assert_called_once_with()
    context_manager.__exit__.assert_called_once_with(None, None, None)
Overwriting test_mocking_open.py
!pytest test_mocking_open.py -v
============================= test session starts ==============================
platform darwin -- Python 3.12.6, pytest-8.2.2, pluggy-1.5.0 -- /Users/a563420/python_training/testing/.venv/bin/python
cachedir: .pytest_cache
rootdir: /Users/a563420/python_training/testing
plugins: cov-6.0.0, anyio-4.4.0
collected 1 item
test_mocking_open.py::test_print_file PASSED [100%]
============================== 1 passed in 0.18s ===============================
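For comparison, unittest.mock ships a mock_open helper that pre-wires __enter__ and read(), so only read_data needs to be set. This sketch patches builtins.open and builtins.print, which reaches print_file because the function has no module-level open or print of its own. print_file is repeated here so the example is self-contained.

```python
from unittest import mock


def print_file(filename):
    # Same function under test as mocking_open.py
    with open(filename) as stream:
        content = stream.read()
    print(content)


def test_print_file_with_mock_open():
    # mock_open returns a mock for open() whose handle.read() yields read_data
    open_mock = mock.mock_open(read_data='42')
    with mock.patch('builtins.open', open_mock), \
            mock.patch('builtins.print') as print_mock:
        print_file('file.txt')
    open_mock.assert_called_once_with('file.txt')
    print_mock.assert_called_once_with('42')
```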
m = mock.MagicMock()
print(m.foo(42, 'bar'))
print(m.foo(42, 'bar'))
print(m.foo())
print(m.boo())
print(m.foo().bar())
a = m.foo()
print(a.bar())
<MagicMock name='mock.foo()' id='4506124160'>
<MagicMock name='mock.foo()' id='4506124160'>
<MagicMock name='mock.foo()' id='4506124160'>
<MagicMock name='mock.boo()' id='4398380624'>
<MagicMock name='mock.foo().bar()' id='4398737760'>
<MagicMock name='mock.foo().bar()' id='4398737760'>
Explanation¶
Import Statements¶
from unittest import mock
from mocking_open import print_file
- unittest.mock: Provides a library for mocking objects in tests.
- print_file: The function we're testing.
Mocking with @mock.patch¶
@mock.patch('mocking_open.open')
@mock.patch('mocking_open.print')
def test_print_file(print_mock, open_mock):
- @mock.patch('mocking_open.open'): mocks the open function in the mocking_open module.
- @mock.patch('mocking_open.print'): mocks the print function in the mocking_open module.
- The order of decorators matters: decorators are applied bottom-up, so the mock created by the bottom decorator is passed as the first argument (print_mock) and the mock from the top decorator as the last one (open_mock).
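A small self-contained sketch of the argument order. It uses mock.patch.object, which follows the same bottom-up rule as mock.patch, on a hypothetical namespace object called tools that stands in for the mocking_open module:

```python
import types
from unittest import mock

# Hypothetical stand-in for a module such as mocking_open.
tools = types.SimpleNamespace(open=open, print=print)

@mock.patch.object(tools, 'open')   # top decorator: its mock is passed last
@mock.patch.object(tools, 'print')  # bottom decorator: its mock is passed first
def which_mock_is_which(print_mock, open_mock):
    # While the function runs, each attribute is replaced by the matching mock.
    return print_mock is tools.print, open_mock is tools.open

result = which_mock_is_which()
print(result)  # (True, True)
```

After the call both attributes are restored, so tools.print is the built-in print again.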
Arrange¶
content = '42'
# Mock the context manager returned by open()
context_manager = open_mock.return_value
# Mock the stream returned by __enter__()
stream = context_manager.__enter__.return_value
# Mock the return value of stream.read()
stream.read.return_value = content
- content: the mocked content to be read from the file.
- open_mock.return_value: the context manager returned by open().
- context_manager.__enter__.return_value: the file stream returned by the context manager's __enter__ method.
- stream.read.return_value = content: the value returned when read() is called on the stream.
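To check that the step-by-step names and the single long chain refer to the same objects, here is a sketch in which a bare MagicMock stands in for the patched open (no real file is touched):

```python
from unittest import mock

open_mock = mock.MagicMock()
# Step-by-step arrangement, as in the test above.
context_manager = open_mock.return_value
stream = context_manager.__enter__.return_value
stream.read.return_value = '42'

# The long chain reaches exactly the same child mock.
same_stream = open_mock.return_value.__enter__.return_value
print(same_stream is stream)  # True

# Using the mock like open() in a with statement yields the arranged stream.
with open_mock('file.txt') as f:
    data = f.read()
print(data)  # 42
```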
Act¶
print_file('file.txt')
- Calls the function under test with the specified filename.
Assert¶
# Ensure open was called once with 'file.txt'
open_mock.assert_called_once_with('file.txt')
# Ensure print was called (without checking arguments)
assert print_mock.called
# Ensure print was called with the content read from the file
print_mock.assert_called_once_with(content)
# Optionally, check that __enter__ and __exit__ were called
context_manager.__enter__.assert_called_once_with()
context_manager.__exit__.assert_called_once_with(None, None, None)
- open_mock.assert_called_once_with('file.txt'): verifies that open was called exactly once with 'file.txt'.
- assert print_mock.called: verifies that print was called.
- print_mock.assert_called_once_with(content): verifies that print was called once, with the expected content.
- Context manager assertions:
  - context_manager.__enter__.assert_called_once_with(): ensures the __enter__ method was called.
  - context_manager.__exit__.assert_called_once_with(None, None, None): ensures the __exit__ method was called with three None arguments, meaning the with block exited without an exception.
Running the Test¶
To run the test, use the following command:
pytest test_mocking_open.py
Make sure pytest is installed and that both mocking_open.py and test_mocking_open.py are in the same directory.
Additional Notes¶
- Handling the context manager:
  - When mocking open, remember that it returns a context manager.
  - The context manager's __enter__ method returns the file stream, which you can mock to control the return value of read().
- Order of mocks:
  - In the test function parameters, print_mock comes before open_mock because the mocks are passed in the reverse order of the decorators (the bottom decorator supplies the first argument).
- Why mock print?
  - Mocking print lets you verify that it was called, and with which arguments, without actually printing to the console during the test.
- Verifying calls without checking arguments:
  - If you only want to verify that a function was called, use assert mock.called or mock.assert_called().
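unittest.mock also provides a helper, mock.mock_open, that wires up the open() -> context manager -> stream chain for you. The sketch below assumes print_file opens the file, reads it and prints the content. Its source is not shown in this section, so the function here is a hypothetical stand-in. The sketch patches builtins.open and builtins.print only inside a with block.

```python
from unittest import mock

# Hypothetical stand-in for mocking_open.print_file.
def print_file(filename):
    with open(filename) as stream:
        print(stream.read())

def check_print_file():
    # mock_open pre-wires open() -> context manager -> stream.read().
    open_mock = mock.mock_open(read_data='42')
    with mock.patch('builtins.open', open_mock), \
            mock.patch('builtins.print') as print_mock:
        print_file('file.txt')
    open_mock.assert_called_once_with('file.txt')
    print_mock.assert_called_once_with('42')
    return open_mock, print_mock

open_mock, print_mock = check_print_file()
```

The trade-off is less explicit control. In exchange, the Arrange step shrinks to a single line.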
Advanced Mocking Techniques¶
In this section, we'll delve deeper into Python's unittest.mock library, exploring more advanced features and nuances of mocks. Understanding these concepts will help you write more precise and effective tests.
Returning a Mock from a Mock¶
When you call a MagicMock instance, it returns another MagicMock instance by default. This allows you to chain calls or access attributes on the returned mock.
Example:
from unittest import mock
m = mock.MagicMock()
print(m) # Outputs: <MagicMock id='...'>
print(m()) # Outputs: <MagicMock name='mock()' id='...'>
print(m()) # Outputs the same as above
print(m().m()) # Outputs: <MagicMock name='mock().m()' id='...'>
print(m().m()) # Outputs the same as above
<MagicMock id='4506071936'>
<MagicMock name='mock()' id='4506068384'>
<MagicMock name='mock()' id='4506068384'>
<MagicMock name='mock().m()' id='4512817200'>
<MagicMock name='mock().m()' id='4512817200'>
- Explanation:
  - m: a MagicMock instance.
  - m(): calling m returns another MagicMock instance named 'mock()'.
  - Each call to m() returns the same mock object (m.return_value) unless configured otherwise.
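A quick sketch confirming the caching behaviour described above: the object returned by a call is stored as return_value and reused on every call, including inside chains.

```python
from unittest import mock

m = mock.MagicMock()
first = m()
second = m()
# Every call returns the same object, stored as m.return_value.
print(first is second)          # True
print(first is m.return_value)  # True
# Chained calls are cached the same way on the child mocks.
print(m().m() is m.return_value.m.return_value)  # True
print(m.call_count)             # 3
```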
Setting Return Values:
You can specify what a mock should return when called.
m = mock.MagicMock()
m.return_value = 42
print(m) # Outputs: <MagicMock id='...'>
print(m()) # Outputs: 42
print(m()) # Outputs: 42
<MagicMock id='4506117296'>
42
42
- Explanation:
  - m.return_value = 42: sets the value returned when m is called.
  - Subsequent calls to m() return 42.
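The same configuration can also be passed to the constructor. A short sketch showing that the mock returns the configured value for any arguments and still records every call:

```python
from unittest import mock

# return_value can be given directly to the constructor.
m = mock.MagicMock(return_value=42)
print(m(), m('any', 'arguments'))  # 42 42
# The mock still records every call.
print(m.call_count)                # 2
print(m.call_args_list)            # [call(), call('any', 'arguments')]
```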
Mock Attributes and Methods¶
Mocks can have attributes and methods that you can configure and inspect.
Example with Attributes:
m = mock.MagicMock()
print(m)
print(m.field)
<MagicMock id='4512848768'>
<MagicMock name='mock.field' id='4512919536'>
m = mock.MagicMock()
m.field = 42
print(m) # Outputs: <MagicMock id='...'>
print(m.field) # Outputs: 42
print(m.field) # Outputs: 42
<MagicMock id='4513093712'>
42
42
m = mock.MagicMock()
# m.field = 42
m.field.return_value = 42
print(m) # Outputs: <MagicMock id='...'>
print(m.field()) # Outputs: 42
print(m.field) # Outputs: <MagicMock name='mock.field' id='...'> (the attribute itself is still a mock)
<MagicMock id='4512935872'>
42
<MagicMock name='mock.field' id='4506073520'>
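- Explanation:
  - You can set attributes on a mock as you would on a regular object (m.field = 42).
  - Setting m.field.return_value = 42 instead configures what calling m.field() returns; m.field itself remains a mock.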
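Several attributes and child return values can also be configured in one step with configure_mock, using dotted keys for child mocks. A sketch (name_of_thing is just an illustrative attribute name):

```python
from unittest import mock

m = mock.MagicMock()
# Dotted keys configure child mocks; plain keys set attributes.
m.configure_mock(**{'field.return_value': 42, 'name_of_thing': 'qwer'})
print(m.field())        # 42
print(m.name_of_thing)  # qwer
```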
Mocking Methods:
If you access an attribute or method that hasn't been set, the mock creates it on the fly as a child mock of the same kind (a Mock for a Mock, a MagicMock for a MagicMock).
m = mock.Mock()
# No arrangement for m.field or m.method
print(m) # Outputs: <Mock id='...'>
print(m.field) # Outputs: <Mock name='mock.field' id='...'>
print(m.method) # Outputs: <Mock name='mock.method' id='...'>
print(m.method(3)) # Outputs: <Mock name='mock.method()' id='...'>
print(m()) # Outputs: <Mock name='mock()' id='...'>
print(m.mock_calls) # Outputs the list of calls made on the mock
<Mock id='4399320048'>
<Mock name='mock.field' id='4513092560'>
<Mock name='mock.method' id='4512935680'>
<Mock name='mock.method()' id='4397134528'>
<Mock name='mock()' id='4398737856'>
[call.method(3), call()]
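The recorded calls can be compared directly with call objects built from mock.call. The sketch below repeats the two calls from the example and also shows method_calls, which tracks only calls made on attributes and methods, not calls to the mock itself:

```python
from unittest import mock

m = mock.Mock()
m.method(3)
m()
# mock_calls compares equal to a list of mock.call objects.
print(m.mock_calls == [mock.call.method(3), mock.call()])  # True
# method_calls ignores the direct call m().
print(m.method_calls)  # [call.method(3)]
```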
MagicMock vs. Mock¶
The MagicMock class is a subclass of Mock that includes default implementations of Python's magic (dunder) methods, such as __len__, __getitem__, etc.
Differences in Behavior:
mm = mock.MagicMock()
m = mock.Mock()
# Using magic methods on MagicMock
print(mm['a'])
print(m['a'])
<MagicMock name='mock.__getitem__()' id='4513010688'>
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[107], line 7
      4 # Using magic methods on MagicMock
      5 print(mm['a'])
----> 7 print(m['a'])
TypeError: 'Mock' object is not subscriptable
mm = mock.MagicMock()
m = mock.Mock()
# Using magic methods on MagicMock
print(len(mm))
print(len(m))
0
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[109], line 7
      4 # Using magic methods on MagicMock
      5 print(len(mm))
----> 7 print(len(m))
TypeError: object of type 'Mock' has no len()
mm = mock.MagicMock()
m = mock.Mock()
# Using magic methods on MagicMock
print( 42 in mm)
print(42 in m)
False
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[111], line 7
      4 # Using magic methods on MagicMock
      5 print( 42 in mm)
----> 7 print(42 in m)
TypeError: argument of type 'Mock' is not iterable
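- Explanation:
  - MagicMock supports magic methods by default (len() returns 0, the in operator returns False, indexing returns a new MagicMock), making it suitable for mocking objects that use them.
  - Mock does not support magic methods unless you configure them explicitly, for example by assigning m.__len__ = mock.Mock(return_value=3).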
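A sketch of configuring magic methods on both classes. On a MagicMock they already exist as child mocks, while on a plain Mock you assign them yourself:

```python
from unittest import mock

mm = mock.MagicMock()
# On a MagicMock, magic methods are pre-created mocks you can configure.
mm.__len__.return_value = 3
mm.__getitem__.side_effect = {'a': 1}.__getitem__
mm.__contains__.return_value = True
print(len(mm), mm['a'], 42 in mm)  # 3 1 True

m = mock.Mock()
# On a plain Mock, assign the magic method explicitly.
m.__len__ = mock.Mock(return_value=5)
print(len(m))  # 5
```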
Using spec and autospec¶
The spec parameter restricts a mock to the attributes and methods that exist on the specified object.
Example without specifying spec:
dumb_os_mock = mock.MagicMock()
dumb_os_mock.remoev() # Typo in method name, but no error
dumb_os_mock.remove.asser_called_once_with('another file') # Typo in the assertion name: nothing is checked, and no error
<MagicMock name='mock.remove.asser_called_once_with()' id='4506093424'>
- Issue:
  - Since we didn't specify a spec, the mock accepts any attribute or method, so typos lead to silent failures.
Using spec to Enforce the Interface:
import os
os_mock = mock.MagicMock(spec=os)
os_mock.remove('file') # Correct usage
# Accessing a non-existent attribute raises AttributeError
# os_mock.remoev # Raises AttributeError
# Typo in method name during assertion is still not caught
os_mock.remove.asser_called_once_with('another_file') # Typo in method name
<MagicMock name='mock.remove.asser_called_once_with()' id='4505391360'>
- Explanation:
  - spec=os restricts the mock to attributes and methods that exist in the os module.
  - Accessing os_mock.remoev (with a typo) raises an AttributeError.
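Limitation:
- Using spec does not catch typos in the assertion methods of child mocks (like asser_called_once_with instead of assert_called_once_with). spec only checks the attributes of os_mock itself, while os_mock.remove is an ordinary, unrestricted MagicMock.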
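autospec closes this gap: mock.create_autospec, which mock.patch(..., autospec=True) uses, also applies specs to child attributes. The sketch below compares the two approaches on the real os module. No file is touched, because only mocks are called.

```python
import os
from unittest import mock

# spec=os checks only the attributes of the top-level mock.
spec_os = mock.MagicMock(spec=os)
spec_os.remove('file')
# spec_os.remove is an unrestricted child mock, so this typo is silently accepted.
spec_os.remove.asser_called_once_with('file')

# create_autospec also gives the child attribute a spec.
auto_os = mock.create_autospec(os)
auto_os.remove('file')
try:
    auto_os.remove.asser_called_once_with('file')
    typo_caught = False
except AttributeError:
    typo_caught = True
print(typo_caught)  # True
auto_os.remove.assert_called_once_with('file')
```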
spec and Instance Attributes¶
Using spec or autospec with classes can sometimes lead to unexpected behavior, especially with instance attributes.
Example:
class Something:
    foo: int
    def __init__(self):
        self.foo = 42
something_mock = mock.Mock(spec=Something)
something_mock.foo # Raises AttributeError
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[122], line 7
      4         self.foo = 42
      6 something_mock = mock.Mock(spec=Something)
----> 7 something_mock.foo # Raises AttributeError
File ~/.pyenv/versions/3.12.6/lib/python3.12/unittest/mock.py:658, in NonCallableMock.__getattr__(self, name)
    656 elif self._mock_methods is not None:
    657     if name not in self._mock_methods or name in _all_magics:
--> 658         raise AttributeError("Mock object has no attribute %r" % name)
    659 elif _is_magic(name):
    660     raise AttributeError(name)
AttributeError: Mock object has no attribute 'foo'
- Issue:
  - The mock does not have the foo attribute because spec builds the list of allowed names from dir(Something), which contains only class-level attributes. foo is assigned in __init__, and the bare annotation foo: int does not create a class attribute.
Solution:
- Manually set the attribute on the mock:
something_mock.foo = 42
print(something_mock.foo)
42
The same spec and autospec options are available when patching a whole module in a test, as in this test module for my_remove_module:
# test_my_remove_module_mock.py
from unittest import mock
from my_remove_module import my_remove
import my_remove_module
@mock.patch('my_remove_module.os', spec=my_remove_module.os)
def test_provided_extension_should_be_used(mock_os):
    filename = 'file.md'
    my_remove(filename)
    mock_os.remove.assert_called_once_with(filename)
@mock.patch('my_remove_module.os', autospec=True)
def test_when_extension_is_missing_then_use_default_one(mock_os):
    my_remove('file')
    mock_os.remove.assert_called_once_with('file.txt')
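Besides setting the attribute by hand, another option is to build the spec from an instance instead of the class. dir() of an instance includes attributes assigned in __init__, so foo becomes allowed while typos are still rejected. The catch is that creating the instance runs __init__, so this only suits classes whose constructor has no side effects.

```python
from unittest import mock

class Something:
    foo: int
    def __init__(self):
        self.foo = 42

# The spec comes from an instance, so attributes set in __init__ are included.
something_mock = mock.Mock(spec=Something())
print(type(something_mock.foo).__name__)  # Mock
try:
    something_mock.fooo  # typo: still rejected
    typo_caught = False
except AttributeError:
    typo_caught = True
print(typo_caught)  # True
```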
Using mock.patch as a Context Manager¶
The mock.patch function can also be used as a context manager to temporarily replace an object within a specific block of code.
Example:
# import builtins
with mock.patch('builtins.sum') as sum_mock:
    print(sum) # Outputs: <MagicMock name='sum' id='...'>
    print(sum_mock) # Outputs: <MagicMock name='sum' id='...'>
print(sum) # Outputs: <built-in function sum> (restored after the with block)
<MagicMock name='sum' id='4399299376'>
<MagicMock name='sum' id='4399299376'>
<built-in function sum>
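Two related forms are useful alongside the string-based patch. mock.patch.object replaces an attribute on an object you already hold. A patcher can also be started and stopped explicitly, for example in setup and teardown code. A sketch using the math module:

```python
import math
from unittest import mock

# patch.object replaces an attribute on an object you already hold.
with mock.patch.object(math, 'sqrt', return_value=3) as sqrt_mock:
    inside = math.sqrt(10)
outside = math.sqrt(16)
print(inside, outside)  # 3 4.0
sqrt_mock.assert_called_once_with(10)

# A patcher can also be started and stopped explicitly.
patcher = mock.patch.object(math, 'pi', 3)
patcher.start()
patched_pi = math.pi
patcher.stop()
print(patched_pi, math.pi == 3)  # 3 False
```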
Using mock.ANY¶
When you need to assert that a call was made with certain arguments but want to ignore others, use mock.ANY.
mock.ANY
# behaves roughly like:
class ANYClass:
    def __eq__(self, other):
        return True
ANY = ANYClass()
with mock.patch('__main__.open') as open_mock:
    with open('qwer', 'r') as s:
        s.read()
open_mock.assert_called_once_with(mock.ANY, 'r')
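mock.ANY works for keyword arguments too, and in plain call comparisons. A short sketch without any patching:

```python
from unittest import mock

m = mock.Mock()
m('qwer', 'r', buffering=1)
# ANY compares equal to anything, so only the remaining arguments are checked.
m.assert_called_once_with(mock.ANY, 'r', buffering=mock.ANY)
print(mock.call(mock.ANY, 'r') == mock.call('other file', 'r'))  # True
print(mock.call(mock.ANY, 'r') == mock.call('other file', 'w'))  # False
```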
Using wraps in Mocks¶
The wraps parameter makes a mock pass calls through to the original object while still recording how it was used.
Example:
import requests
from unittest import mock
requests_mock = mock.MagicMock(wraps=requests)
# This sends a real HTTP request (the call is delegated to requests.get), so it needs network access.
response = requests_mock.get('https://api.github.com')
print(response.status_code) # Outputs: 200 (assuming the request succeeds)
print(requests_mock.mock_calls) # Records the calls made on the mock
print(requests_mock.get.call_args)
200
[call.get('https://api.github.com')]
call('https://api.github.com')
- Explanation:
  - wraps=requests means the mock delegates attribute access and calls to the real requests module.
  - The mock still records the calls, so you can assert how it was used.
Use Cases for wraps:
- Monitoring interactions with an object without completely mocking out its behavior.
- Spying on real objects.
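A spy that needs no network, sketched with a hypothetical Calculator class: the real method runs, and the mock still records the call.

```python
from unittest import mock

class Calculator:  # hypothetical class used only for this sketch
    def add(self, a, b):
        return a + b

spy = mock.MagicMock(wraps=Calculator())
# The real method runs, so the real result comes back...
print(spy.add(2, 3))  # 5
# ...and the call is still recorded for assertions.
spy.add.assert_called_once_with(2, 3)
print(spy.mock_calls)  # [call.add(2, 3)]
```

If you set an explicit return value (for example spy.add.return_value = 0), the mock returns that value instead of calling the wrapped method.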
Exercise: Implementing and Testing an NBP API Wrapper with Mocking¶
Objective¶
In this exercise, you will implement the get_exchange_rate function, which retrieves currency exchange rates from the National Bank of Poland (NBP) API. You will also write unit tests for this function using pytest and unittest.mock to ensure it behaves correctly under various scenarios without making actual HTTP requests.
Description¶
You are provided with a partial implementation of the get_exchange_rate function in the nbp_api.py module. This function is designed to fetch the exchange rate for a specified currency and date from the NBP API. Your tasks are:
1. Implement the get_exchange_rate function:
   - Complete the function to make an HTTP GET request to the NBP API.
   - Parse the JSON response to extract the exchange rate for the specified currency.
   - Handle cases where no data is available for the given date by raising a NoData exception.
   - Handle invalid currency codes by raising a ValueError.
2. Write unit tests for get_exchange_rate:
   - Use pytest and unittest.mock to mock external dependencies (requests.get).
   - Test the function's behavior under different scenarios:
     - Successful retrieval of exchange rates.
     - No data available for the given date.
     - Invalid currency codes.
Initial Code¶
nbp_api.py
# nbp_api.py
from datetime import date
import requests
class NoData(Exception):
    """Exception raised when no data is available for the given date."""
    pass
API_ENDPOINT = 'http://api.nbp.pl/api/exchangerates/tables/a/{date}/?format=json'
DATE_FORMAT = '%Y-%m-%d'
def get_exchange_rate(currency: str, date: date) -> float:
    """Fetches the exchange rate for a given currency and date from the NBP API.
    Args:
        currency (str): The currency code (e.g., 'USD', 'EUR').
        date (date): The date for which to retrieve the exchange rate.
    Returns:
        float: The exchange rate.
    Raises:
        NoData: If no data is available for the given date.
        ValueError: If the provided currency code is invalid.
    """
    currency = currency.upper()
    date_as_str = date.strftime(DATE_FORMAT)
    url = API_ENDPOINT.format(date=date_as_str)
    # TODO: Make an HTTP GET request to url.
    # TODO: If the response status code does not indicate success, raise NoData.
    # TODO: Parse the JSON response.
    # TODO: Extract the exchange rate for the specified currency.
    # TODO: If the currency code is not found, raise ValueError.
test_nbp_api.py
# test_nbp_api.py
import pytest
from unittest import mock
from datetime import date
from nbp_api import get_exchange_rate, NoData
# You will write your tests here.
Advice for Participants¶
- Understand the functionality: before writing tests, make sure you understand what the get_exchange_rate function is supposed to do, including its inputs, outputs, and how it handles errors.
- Mock external calls: use unittest.mock to mock the requests.get method. This prevents actual HTTP requests during testing and allows you to simulate different responses.
- Use autospec: when mocking requests.get, pass autospec=True. The mock then matches the signature of the real requests.get, which helps catch errors such as incorrect arguments.
- Test different scenarios:
  - Successful response: simulate a successful API response with valid exchange rate data.
  - No data available: simulate a response with a non-200 status code to trigger the NoData exception.
  - Invalid currency code: simulate a successful response that does not include the requested currency, triggering a ValueError.
  - Network exceptions: simulate network-related exceptions (e.g., connection errors) to ensure your function handles them gracefully.
- Use mock.ANY: when asserting calls where some arguments are irrelevant or variable, use mock.ANY to ignore those specific arguments.
- Handle exceptions in tests: use pytest.raises to verify that your function raises the appropriate exceptions under error conditions.
- Keep tests isolated: make sure each test is independent and does not rely on state or side effects from other tests.
- Run tests frequently: as you implement the function and write tests, run them often to catch and fix issues early.
Expected Behaviour¶
Successful Retrieval¶
get_exchange_rate('USD', date(2021, 1, 13))
3.7142
No Data Available¶
get_exchange_rate('USD', date(2021, 1, 10))
Traceback (most recent call last):
File "<ipython-input-164-c1f5ecb2786e>", line 1, in <module>
get_exchange_rate('USD', date(2021, 1, 10))
File "<ipython-input-162-71c2eb3d0b2f>", line 24, in get_exchange_rate
raise NoData
NoData
Invalid Currency Code¶
get_exchange_rate('XXX', date(2021, 1, 13))
Traceback (most recent call last):
File "<ipython-input-165-0d9eb2e858cd>", line 1, in <module>
get_exchange_rate('XXX', date(2021, 1, 13))
File "<ipython-input-162-71c2eb3d0b2f>", line 37, in get_exchange_rate
raise ValueError("Invalid currency")
ValueError: Invalid currency
Solution¶
# %%writefile nbp_api.py
from datetime import date
# os and sys are not used here; the tests patch them to demonstrate combining decorators with fixtures.
import os
import sys
import requests
class NoData(Exception):
    pass
API_ENDPOINT = 'http://api.nbp.pl/api/exchangerates/tables/a/{date}/?format=json'
DATE_FORMAT = '%Y-%m-%d'
def get_exchange_rate(currency: str, date: date) -> float:
    """May raise NoData or ValueError."""
    currency = currency.upper()
    date_as_str = date.strftime(DATE_FORMAT)
    url = API_ENDPOINT.format(date=date_as_str)
    response = requests.get(url)
    if response.status_code != 200:
        raise NoData
    json = response.json()
    # rates = (r['mid'] for r in json[0]['rates'] if r['code'] == currency)
    # try:
    #     rate = next(rates)
    # except StopIteration:
    #     raise ValueError("Invalid currency")
    # else:
    #     return rate
    for rate in json[0]['rates']:
        if rate['code'] == currency:
            return rate['mid']
    raise ValueError("Invalid currency")
get_exchange_rate('USD', date(2021, 1, 13))
3.7142
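The commented-out generator variant can also be written with next() and a default value, which avoids handling StopIteration. Below is a sketch of just the lookup step, run offline on a trimmed copy of the table structure used in the tests. find_rate and RAW are illustrative names and not part of nbp_api.

```python
import json

# Trimmed copy of the NBP table structure: a list with one table whose "rates"
# entries carry a currency code and a mid rate.
RAW = '[{"table":"A","effectiveDate":"2019-11-08","rates":[{"code":"USD","mid":3.8625},{"code":"EUR","mid":4.2638}]}]'

def find_rate(table_json, currency):
    # Hypothetical helper mirroring the loop in get_exchange_rate.
    rates = table_json[0]['rates']
    rate = next((r['mid'] for r in rates if r['code'] == currency.upper()), None)
    if rate is None:
        raise ValueError("Invalid currency")
    return rate

data = json.loads(RAW)
print(find_rate(data, 'usd'))  # 3.8625
```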
%%writefile nbp_api_tests.py
import datetime
import json
from unittest import mock
import pytest
from nbp_api import get_exchange_rate, NoData
RAW_JSON: str = """[{"table":"A","no":"217/A/NBP/2019","effectiveDate":"2019-11-08","rates":[{"currency":"bat (Tajlandia)","code":"THB","mid":0.1272},{"currency":"dolar amerykański","code":"USD","mid":3.8625},{"currency":"dolar australijski","code":"AUD","mid":2.6533},{"currency":"dolar Hongkongu","code":"HKD","mid":0.4935},{"currency":"dolar kanadyjski","code":"CAD","mid":2.9263},{"currency":"dolar nowozelandzki","code":"NZD","mid":2.4520},{"currency":"dolar singapurski","code":"SGD","mid":2.8406},{"currency":"euro","code":"EUR","mid":4.2638},{"currency":"forint (Węgry)","code":"HUF","mid":0.01278},{"currency":"frank szwajcarski","code":"CHF","mid":3.8797},{"currency":"funt szterling","code":"GBP","mid":4.9476},{"currency":"hrywna (Ukraina)","code":"UAH","mid":0.1577},{"currency":"jen (Japonia)","code":"JPY","mid":0.035324},{"currency":"korona czeska","code":"CZK","mid":0.1669},{"currency":"korona duńska","code":"DKK","mid":0.5706},{"currency":"korona islandzka","code":"ISK","mid":0.030964},{"currency":"korona norweska","code":"NOK","mid":0.4221},{"currency":"korona szwedzka","code":"SEK","mid":0.3991},{"currency":"kuna (Chorwacja)","code":"HRK","mid":0.5739},{"currency":"lej rumuński","code":"RON","mid":0.8956},{"currency":"lew (Bułgaria)","code":"BGN","mid":2.1800},{"currency":"lira turecka","code":"TRY","mid":0.6713},{"currency":"nowy izraelski szekel","code":"ILS","mid":1.1053},{"currency":"peso chilijskie","code":"CLP","mid":0.005206},{"currency":"peso filipińskie","code":"PHP","mid":0.0764},{"currency":"peso meksykańskie","code":"MXN","mid":0.2012},{"currency":"rand (Republika Południowej Afryki)","code":"ZAR","mid":0.2603},{"currency":"real (Brazylia)","code":"BRL","mid":0.9419},{"currency":"ringgit (Malezja)","code":"MYR","mid":0.9344},{"currency":"rubel rosyjski","code":"RUB","mid":0.0605},{"currency":"rupia indonezyjska","code":"IDR","mid":0.00027562},{"currency":"rupia indyjska","code":"INR","mid":0.054185},{"currency":"won południowokoreański","code":"KRW","mid":0.003337},{"currency":"yuan renminbi (Chiny)","code":"CNY","mid":0.5523},{"currency":"SDR (MFW)","code":"XDR","mid":5.2969}]}]"""
JSON: list = json.loads(RAW_JSON)
@mock.patch('nbp_api.requests.get')
def test_should_return_exchange_rate_for_a_working_date(requests_get_mock):
    response_mock = requests_get_mock.return_value
    response_mock.status_code = 200
    response_mock.json.return_value = JSON
    date = datetime.date(2019, 11, 8)
    rate = get_exchange_rate('USD', date)
    assert rate == 3.8625
    requests_get_mock.assert_called_once_with(
        'http://api.nbp.pl/api/exchangerates/tables/a/2019-11-08/?format=json')
@mock.patch('nbp_api.requests')
def test_should_raise_NoData_for_a_nonworking_date(requests_mock):
    response_mock = requests_mock.get.return_value
    response_mock.status_code = 404
    date = datetime.date(2019, 11, 9)
    with pytest.raises(NoData):
        get_exchange_rate('USD', date)
    requests_mock.get.assert_called_once_with(
        'http://api.nbp.pl/api/exchangerates/tables/a/2019-11-09/?format=json')
    response_mock.json.assert_not_called()
@mock.patch('nbp_api.requests')
def test_should_raise_ValueError_when_currency_not_found(requests_mock):
    response_mock = requests_mock.get.return_value
    response_mock.status_code = 200
    response_mock.json.return_value = JSON
    date = datetime.date(2019, 11, 8)
    with pytest.raises(ValueError):
        get_exchange_rate('XXX', date)
Overwriting nbp_api_tests.py
!pytest nbp_api_tests.py -v
============================= test session starts ==============================
platform darwin -- Python 3.12.6, pytest-8.2.2, pluggy-1.5.0 -- /Users/a563420/python_training/testing/.venv/bin/python
cachedir: .pytest_cache
rootdir: /Users/a563420/python_training/testing
plugins: cov-6.0.0, anyio-4.4.0
collected 3 items
nbp_api_tests.py::test_should_return_exchange_rate_for_a_working_date PASSED [ 33%]
nbp_api_tests.py::test_should_raise_NoData_for_a_nonworking_date PASSED [ 66%]
nbp_api_tests.py::test_should_raise_ValueError_when_currency_not_found PASSED [100%]
============================== 3 passed in 0.40s ===============================
%%writefile nbp_api_tests.py
import datetime
import json
from unittest import mock
import pytest
from nbp_api import get_exchange_rate, NoData
RAW_JSON: str = """[{"table":"A","no":"217/A/NBP/2019","effectiveDate":"2019-11-08","rates":[{"currency":"bat (Tajlandia)","code":"THB","mid":0.1272},{"currency":"dolar amerykański","code":"USD","mid":3.8625},{"currency":"dolar australijski","code":"AUD","mid":2.6533},{"currency":"dolar Hongkongu","code":"HKD","mid":0.4935},{"currency":"dolar kanadyjski","code":"CAD","mid":2.9263},{"currency":"dolar nowozelandzki","code":"NZD","mid":2.4520},{"currency":"dolar singapurski","code":"SGD","mid":2.8406},{"currency":"euro","code":"EUR","mid":4.2638},{"currency":"forint (Węgry)","code":"HUF","mid":0.01278},{"currency":"frank szwajcarski","code":"CHF","mid":3.8797},{"currency":"funt szterling","code":"GBP","mid":4.9476},{"currency":"hrywna (Ukraina)","code":"UAH","mid":0.1577},{"currency":"jen (Japonia)","code":"JPY","mid":0.035324},{"currency":"korona czeska","code":"CZK","mid":0.1669},{"currency":"korona duńska","code":"DKK","mid":0.5706},{"currency":"korona islandzka","code":"ISK","mid":0.030964},{"currency":"korona norweska","code":"NOK","mid":0.4221},{"currency":"korona szwedzka","code":"SEK","mid":0.3991},{"currency":"kuna (Chorwacja)","code":"HRK","mid":0.5739},{"currency":"lej rumuński","code":"RON","mid":0.8956},{"currency":"lew (Bułgaria)","code":"BGN","mid":2.1800},{"currency":"lira turecka","code":"TRY","mid":0.6713},{"currency":"nowy izraelski szekel","code":"ILS","mid":1.1053},{"currency":"peso chilijskie","code":"CLP","mid":0.005206},{"currency":"peso filipińskie","code":"PHP","mid":0.0764},{"currency":"peso meksykańskie","code":"MXN","mid":0.2012},{"currency":"rand (Republika Południowej Afryki)","code":"ZAR","mid":0.2603},{"currency":"real (Brazylia)","code":"BRL","mid":0.9419},{"currency":"ringgit (Malezja)","code":"MYR","mid":0.9344},{"currency":"rubel rosyjski","code":"RUB","mid":0.0605},{"currency":"rupia indonezyjska","code":"IDR","mid":0.00027562},{"currency":"rupia indyjska","code":"INR","mid":0.054185},{"currency":"won południowokoreański","code":"KRW","mid":0.003337},{"currency":"yuan renminbi (Chiny)","code":"CNY","mid":0.5523},{"currency":"SDR (MFW)","code":"XDR","mid":5.2969}]}]"""
JSON: list = json.loads(RAW_JSON)
@pytest.fixture
def preconfigured_requests_mock():
    # Patch requests for the duration of one test and pre-arrange a successful response.
    with mock.patch('nbp_api.requests') as requests_mock:
        response_mock = requests_mock.get.return_value
        response_mock.status_code = 200
        response_mock.json.return_value = JSON
        yield requests_mock
def test_should_return_exchange_rate_for_a_working_date(preconfigured_requests_mock):
    date = datetime.date(2019, 11, 8)
    rate = get_exchange_rate('USD', date)
    assert rate == 3.8625
    preconfigured_requests_mock.get.assert_called_once_with(
        'http://api.nbp.pl/api/exchangerates/tables/a/2019-11-08/?format=json')
def test_should_raise_NoData_for_a_nonworking_date(preconfigured_requests_mock):
    # Override the fixture's default status code for this test only.
    response_mock = preconfigured_requests_mock.get.return_value
    response_mock.status_code = 404
    date = datetime.date(2019, 11, 9)
    with pytest.raises(NoData):
        get_exchange_rate('USD', date)
    preconfigured_requests_mock.get.assert_called_once_with(
        'http://api.nbp.pl/api/exchangerates/tables/a/2019-11-09/?format=json')
    # json() is a method of the response, not of the requests module.
    response_mock.json.assert_not_called()
# os and sys are patched only to show that mock.patch decorators and pytest fixtures
# can be combined: the patch-injected mocks come first (bottom decorator first),
# followed by the fixtures.
@mock.patch('nbp_api.os')
@mock.patch('nbp_api.sys')
def test_should_raise_ValueError_when_currency_not_found(sys_mock, os_mock, preconfigured_requests_mock):
    date = datetime.date(2019, 11, 8)
    with pytest.raises(ValueError):
        get_exchange_rate('XXX', date)
Overwriting nbp_api_tests.py
!pytest nbp_api_tests.py -v
============================= test session starts ==============================
platform darwin -- Python 3.12.6, pytest-8.2.2, pluggy-1.5.0 -- /Users/a563420/python_training/testing/.venv/bin/python
cachedir: .pytest_cache
rootdir: /Users/a563420/python_training/testing
plugins: cov-6.0.0, anyio-4.4.0
collected 3 items
nbp_api_tests.py::test_should_return_exchange_rate_for_a_working_date PASSED [ 33%]
nbp_api_tests.py::test_should_raise_NoData_for_a_nonworking_date PASSED [ 66%]
nbp_api_tests.py::test_should_raise_ValueError_when_currency_not_found PASSED [100%]
============================== 3 passed in 0.18s ===============================
Behaviour-Driven Development (BDD)¶
Behaviour-Driven Development (BDD) is an Agile software development process that encourages collaboration among developers, quality assurance, and non-technical or business participants in a software project. It extends Test-Driven Development (TDD) by writing test cases in natural language that non-programmers can read.
Key Concepts¶
- User Stories: Descriptions of features from the perspective of the end-user.
- Scenarios: Specific situations that demonstrate how a feature should behave.
- Acceptance Tests: Tests written in a way that describes the desired behavior of the system from the user's perspective.
BDD Workflow¶
The BDD process involves several iterative steps:
Planning for BDD¶
Before diving into coding, it's essential to plan out the features and how they will be tested. This includes:
- Identifying user stories and scenarios.
- Planning the structure of your application.
- Deciding on the tools and frameworks you'll use.
BDD Demo: Calculator Application¶
requirements.txt¶
%%writefile calculator/requirements.txt
Flask
WebTest
behave
pytest
Overwriting calculator/requirements.txt
backend.py¶
%%writefile calculator/backend.py
def add(a, b):
    return a + b
Overwriting calculator/backend.py
frontend.py¶
%%writefile calculator/frontend.py
from flask import request, Flask
import backend
app = Flask(__name__)
HOME_PAGE = """
<h1>Home Page</h1>
<form method="POST">
<input name="first" /> +
<input name="second" /> =
<input type="submit" value="?" />
</form>
"""
FORM_SENT_TEMPLATE = "{first} + {second} = {sum}"
@app.route('/', methods=['GET', 'POST'])
def home():
    if request.method == 'GET':
        return HOME_PAGE
    else:
        first = float(request.form['first'])
        second = float(request.form['second'])
        sum_ = backend.add(first, second)
        return FORM_SENT_TEMPLATE.format(
            first=first, second=second, sum=sum_)
Overwriting calculator/frontend.py
test_backend.py¶
%%writefile calculator/test_backend.py
from backend import add
def test_addition():
    assert add(50.0, 70.0) == 120.0
Overwriting calculator/test_backend.py
!pytest calculator/test_backend.py
============================= test session starts ==============================
platform darwin -- Python 3.12.6, pytest-8.2.2, pluggy-1.5.0
rootdir: /Users/a563420/python_training/testing
plugins: cov-6.0.0, anyio-4.4.0
collected 1 item
calculator/test_backend.py . [100%]
============================== 1 passed in 0.16s ===============================
test_frontend.py¶
%%writefile calculator/test_frontend.py
from unittest import mock
import frontend
@mock.patch('frontend.request', spec=object())
def test_should_display_home_page(request_mock):
    request_mock.method = 'GET'
    html = frontend.home()
    assert 'Home Page' in html
@mock.patch('frontend.request', spec=object())
def test_should_display_addition_form(request_mock):
    request_mock.method = 'GET'
    html = frontend.home()
    assert '<form' in html
    assert 'name="first"' in html
    assert 'name="second"' in html
    assert 'type="submit"' in html
@mock.patch('frontend.request', spec=object())
@mock.patch('frontend.backend')
def test_should_delegate_request_parameters_to_backend(
        backend_mock, request_mock):
    # arrange (given)
    request_mock.method = 'POST'
    request_mock.form = {
        'first': '50',
        'second': '70',
    }
    backend_mock.add.return_value = 120
    # action (when)
    html = frontend.home()
    # assert (then)
    backend_mock.add.assert_called_once_with(50.0, 70.0)
    assert '120' in html
Overwriting calculator/test_frontend.py
!pytest calculator/test_frontend.py
============================= test session starts ==============================
platform darwin -- Python 3.12.6, pytest-8.2.2, pluggy-1.5.0
rootdir: /Users/a563420/python_training/testing
plugins: cov-6.0.0, anyio-4.4.0
collected 3 items
calculator/test_frontend.py ... [100%]
============================== 3 passed in 0.64s ===============================
features/addition.feature¶
%%writefile calculator/features/addition.feature
Feature: add numbers
  In order to avoid silly mistakes
  I want to be told the sum of two numbers
  Scenario: Navigation to Home Page
    When a user navigates to home page
    Then Home Page should be displayed
  Scenario: Add two numbers
    Given a user navigates to home page
    And I have entered 50 as first number
    And I have entered 70 as second number
    When I press add
    Then 120 should be displayed
features/steps/addition_steps.py
%%writefile calculator/features/steps/addition_steps.py
from behave import *
@step('a user navigates to home page')
def step_impl(context):
    context.resp = context.client.get('/')
@then('{text} should be displayed')
def step_impl(context, text):
    assert text in context.resp
@given('I have entered {value} as {field} number')
def step_impl(context, value, field):
    context.resp.form[field] = value
@when('I press add')
def step_impl(context):
    context.resp = context.resp.form.submit()
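behave's default step matcher uses the parse library: each {name} placeholder captures a piece of the step text and passes it to the step function as a string. The sketch below approximates that matching with a regular expression to show what each pattern captures. It is an illustration only, not behave's implementation, and step_regex and match_step are hypothetical helpers.

```python
import re

def step_regex(pattern):
    # Rough approximation: every {name} captures some text and the whole
    # step line must match.
    escaped = re.escape(pattern)
    return re.compile(re.sub(r'\\\{(\w+)\\\}', r'(?P<\1>.+?)', escaped))

def match_step(pattern, text):
    m = step_regex(pattern).fullmatch(text)
    return m.groupdict() if m else None

print(match_step('{text} should be displayed', 'Home Page should be displayed'))
# {'text': 'Home Page'}
print(match_step('I have entered {value} as {field} number',
                 'I have entered 50 as first number'))
# {'value': '50', 'field': 'first'}
```

The captured value is the string '50', which is why the step can assign it directly to the form field.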
features/environment.py
%%writefile calculator/features/environment.py
from webtest import TestApp

from frontend import app


def before_scenario(context, scenario):
    context.client = TestApp(app)


def after_scenario(context, scenario):
    del context.client
Overwriting calculator/features/environment.py
! cd calculator && behave --no-source
Feature: add numbers
  In order to avoid silly mistakes
  I want to be told the sum of two numbers

  Scenario: Navigation to Home Page
    When a user navigates to home page # 0.005s
    Then Home Page should be displayed # 0.000s

  Scenario: Add two numbers
    Given a user navigates to home page # 0.001s
    And I have entered 50 as first number # 0.002s
    And I have entered 70 as second number # 0.000s
    When I press add # 0.003s
    Then 120 should be displayed # 0.000s

1 feature passed, 0 failed, 0 skipped
2 scenarios passed, 0 failed, 0 skipped
7 steps passed, 0 failed, 0 skipped, 0 undefined
Took 0m0.012s
Exercise: 🏠 NBP¶
If you're using Python 3.8+, specify the spec parameter when mocking flask.request as a workaround for a bug:
@mock.patch(..., spec=object())
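The sketch below shows what spec=object() does to the resulting mock. Attributes that the test assigns, such as method or form, work as usual. Reading an attribute that was never assigned raises AttributeError instead of silently returning a new mock. A SimpleNamespace named frontend_like is a hypothetical stand-in for the frontend module, so the example runs without Flask. The sketch illustrates the mock's behaviour; it does not reproduce the Python 3.8+ bug itself.

```python
from types import SimpleNamespace
from unittest import mock

# Hypothetical stand-in for a module (like frontend) that imported `request`
frontend_like = SimpleNamespace(request="real request proxy")

with mock.patch.object(frontend_like, "request", spec=object()) as request_mock:
    request_mock.method = "POST"                 # assigning attributes is still allowed
    method_seen = frontend_like.request.method   # code in the module now sees the mock
    try:
        request_mock.form                        # never assigned, and object() has no 'form'
        form_missing = False
    except AttributeError:
        form_missing = True

# After the with-block the original object is restored
restored = frontend_like.request
```

This is why the tests in this section always set request_mock.method, and request_mock.form for POST requests, before calling frontend.home().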
NBP API Description¶
Everyone! Experiment with the calculator. Run unit and acceptance tests.
Everyone! Create the directory and file structure for NBP from scratch, using the same structure as the calculator.
BDD for NBP: Switch roles (change who is at the keyboard) after the first acceptance test.
Explanation¶
Mocking flask.request in Python 3.8+:
- Due to a known bug in Python 3.8 and above, when mocking flask.request you must specify the spec parameter so that the mock object behaves correctly. To do this, pass spec=object() to the @mock.patch decorator.
NBP API Description:
Hands-On Practice: Participants are encouraged to engage with the calculator application by running both unit tests and acceptance tests. This hands-on approach reinforces understanding of testing methodologies.
Project Setup: Participants should create the necessary directory and file structure for the NBP (National Bank of Poland) API wrapper from scratch. The structure should mirror that of the calculator application to maintain consistency and leverage established best practices.
Behaviour-Driven Development (BDD) for NBP:
- Team Rotation: After completing the first acceptance test, participants should rotate roles. This ensures that different team members gain experience in various aspects of BDD, fostering a more collaborative and well-rounded development process.
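The NBP layout mirrors the calculator, as the file headings in this exercise show: backend.py, frontend.py, their two test modules, and a features/ directory with a steps/ subdirectory. The script below is a hypothetical helper, not part of the course material, that creates empty files for that layout. Here it runs inside a temporary directory so it leaves nothing behind. Point create_skeleton at a real path to use it.

```python
import tempfile
from pathlib import Path

# Layout mirroring the calculator project (taken from this exercise's file headings)
NBP_LAYOUT = [
    "backend.py",
    "frontend.py",
    "test_backend.py",
    "test_frontend.py",
    "features/exchange_rate.feature",
    "features/environment.py",
    "features/steps/exchange_rate_steps.py",
]


def create_skeleton(root, layout=NBP_LAYOUT):
    """Create empty files (and any missing parent directories) under `root`."""
    root = Path(root)
    for relative in layout:
        path = root / relative
        path.parent.mkdir(parents=True, exist_ok=True)
        path.touch(exist_ok=True)
    # Report the files that now exist, as sorted POSIX-style relative paths
    return sorted(p.relative_to(root).as_posix() for p in root.rglob("*") if p.is_file())


with tempfile.TemporaryDirectory() as tmp:
    created = create_skeleton(Path(tmp) / "nbp")
```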
backend.py¶
%%writefile nbp/backend.py
import datetime

import requests


class NoData(Exception):
    pass


API_ENDPOINT = 'http://api.nbp.pl/api/exchangerates/tables/a/{}/?format=json'
DATE_FORMAT = '%Y-%m-%d'


def get_exchange_rate(currency, date):
    currency = currency.upper()
    date_as_str = date.strftime(DATE_FORMAT)
    url = API_ENDPOINT.format(date_as_str)
    response = requests.get(url)
    if response.status_code == 404:
        # No exchange-rate table for this date (e.g. a weekend)
        raise NoData
    json = response.json()
    rates = [rate['mid'] for rate in json[0]['rates']
             if rate['code'] == currency]
    try:
        rate = rates[0]
    except IndexError:
        raise ValueError("Invalid currency")
    else:
        return rate
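get_exchange_rate depends on the shape of the table-A response. The response is a JSON list holding one table, and that table's rates entry is a list of objects with code and mid fields. The sketch below isolates that lookup on a trimmed copy of the sample payload from test_backend.py. find_mid_rate is a hypothetical helper: it uses next() instead of building a list and catching IndexError, but otherwise follows the same logic.

```python
import json

# Trimmed copy of the table-A payload used in test_backend.py (two of its rates)
SAMPLE = """[{"table": "A", "no": "217/A/NBP/2019", "effectiveDate": "2019-11-08",
  "rates": [{"currency": "bat (Tajlandia)", "code": "THB", "mid": 0.1272},
            {"currency": "dolar amerykański", "code": "USD", "mid": 3.8625}]}]"""


def find_mid_rate(tables, currency):
    """Return the 'mid' rate for `currency` from a parsed table-A response."""
    rates = tables[0]["rates"]  # the response is a list containing a single table
    rate = next((r["mid"] for r in rates if r["code"] == currency.upper()), None)
    if rate is None:
        raise ValueError("Invalid currency")
    return rate


tables = json.loads(SAMPLE)
usd = find_mid_rate(tables, "usd")  # lookup is case-insensitive, like get_exchange_rate
```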
frontend.py¶
%%writefile nbp/frontend.py
from datetime import datetime

from flask import Flask, request

import backend

app = Flask(__name__)

HOME_PAGE_TEMPLATE = """
<h1>Home Page</h1>
<p>{msg}</p>
<form method="POST">
    <p>Date: <input type="text" name="date" /></p>
    <p>Currency
        <select name="currency">
            <option value="USD">dolar amerykanski USD</option>
            <option value="THB">bat (Tajlandia) THB</option>
            <option value="ISK">korona islandzka ISK</option>
        </select>
    </p>
    <p><input type="submit" value="Get exchange rate!"></p>
</form>
"""
EXCHANGE_RATE_TEMPLATE = \
    '1.00 {currency} = {rate} PLN'
NO_DATA_MSG = 'No data for this day.'
DATE_FORMAT = '%Y/%m/%d'


@app.route('/', methods=['GET', 'POST'])
def home():
    # todo
    pass


if __name__ == "__main__":
    app.run()
test_backend.py¶
%%writefile nbp/test_backend.py
import datetime
import json
from unittest import mock
import pytest
from backend import get_exchange_rate, NoData
JSON = """[{"table":"A","no":"217/A/NBP/2019","effectiveDate":"2019-11-08","rates":[{"currency":"bat (Tajlandia)","code":"THB","mid":0.1272},{"currency":"dolar amerykański","code":"USD","mid":3.8625},{"currency":"dolar australijski","code":"AUD","mid":2.6533},{"currency":"dolar Hongkongu","code":"HKD","mid":0.4935},{"currency":"dolar kanadyjski","code":"CAD","mid":2.9263},{"currency":"dolar nowozelandzki","code":"NZD","mid":2.4520},{"currency":"dolar singapurski","code":"SGD","mid":2.8406},{"currency":"euro","code":"EUR","mid":4.2638},{"currency":"forint (Węgry)","code":"HUF","mid":0.01278},{"currency":"frank szwajcarski","code":"CHF","mid":3.8797},{"currency":"funt szterling","code":"GBP","mid":4.9476},{"currency":"hrywna (Ukraina)","code":"UAH","mid":0.1577},{"currency":"jen (Japonia)","code":"JPY","mid":0.035324},{"currency":"korona czeska","code":"CZK","mid":0.1669},{"currency":"korona duńska","code":"DKK","mid":0.5706},{"currency":"korona islandzka","code":"ISK","mid":0.030964},{"currency":"korona norweska","code":"NOK","mid":0.4221},{"currency":"korona szwedzka","code":"SEK","mid":0.3991},{"currency":"kuna (Chorwacja)","code":"HRK","mid":0.5739},{"currency":"lej rumuński","code":"RON","mid":0.8956},{"currency":"lew (Bułgaria)","code":"BGN","mid":2.1800},{"currency":"lira turecka","code":"TRY","mid":0.6713},{"currency":"nowy izraelski szekel","code":"ILS","mid":1.1053},{"currency":"peso chilijskie","code":"CLP","mid":0.005206},{"currency":"peso filipińskie","code":"PHP","mid":0.0764},{"currency":"peso meksykańskie","code":"MXN","mid":0.2012},{"currency":"rand (Republika Południowej Afryki)","code":"ZAR","mid":0.2603},{"currency":"real (Brazylia)","code":"BRL","mid":0.9419},{"currency":"ringgit (Malezja)","code":"MYR","mid":0.9344},{"currency":"rubel rosyjski","code":"RUB","mid":0.0605},{"currency":"rupia indonezyjska","code":"IDR","mid":0.00027562},{"currency":"rupia indyjska","code":"INR","mid":0.054185},{"currency":"won 
południowokoreański","code":"KRW","mid":0.003337},{"currency":"yuan renminbi (Chiny)","code":"CNY","mid":0.5523},{"currency":"SDR (MFW)","code":"XDR","mid":5.2969}]}]"""
def preconfigure_requests_mock(requests_mock):
    response = requests_mock.get.return_value
    response.status_code = 200
    response.json.return_value = json.loads(JSON)


@mock.patch('backend.requests')
def test_should_return_exchange_rate_for_a_working_date(requests_mock):
    preconfigure_requests_mock(requests_mock)
    date = datetime.date(2019, 11, 8)
    rate = get_exchange_rate('USD', date)
    assert rate == 3.8625
    requests_mock.get.assert_called_once_with(
        'http://api.nbp.pl/api/exchangerates/tables/a/2019-11-08/?format=json')


@mock.patch('backend.requests')
def test_should_raise_NoData_for_a_nonworking_date(requests_mock):
    response = requests_mock.get.return_value
    response.status_code = 404
    date = datetime.date(2019, 11, 9)
    with pytest.raises(NoData):
        get_exchange_rate('USD', date)
    requests_mock.get.assert_called_once_with(
        'http://api.nbp.pl/api/exchangerates/tables/a/2019-11-09/?format=json')


@mock.patch('backend.requests')
def test_should_raise_ValueError_when_currency_not_found(requests_mock):
    preconfigure_requests_mock(requests_mock)
    date = datetime.date(2019, 11, 8)
    with pytest.raises(ValueError):
        get_exchange_rate('XYZ', date)
    requests_mock.get.assert_called_once_with(
        'http://api.nbp.pl/api/exchangerates/tables/a/2019-11-08/?format=json')
test_frontend.py¶
%%writefile nbp/test_frontend.py
import datetime
from unittest import mock

import frontend
import backend


@mock.patch  # todo
class TestHomePage:
    def test_home_page_should_have_exchange_rate_form(self, request_mock):
        # todo
        pass

    @mock.patch  # todo
    def test_home_page_should_work_for_working_days(self, backend_mock, request_mock):
        # todo
        pass

    @mock.patch  # todo
    def test_home_page_should_work_for_non_working_days(self, get_exchange_rate_mock, request_mock):
        # todo
        pass
features/exchange_rate.feature¶
%%writefile nbp/features/exchange_rate.feature
Feature: Exchange rate form
In order to become rich
As a future currency speculator
I'd like to know currency exchange rates
Scenario: Navigation to Home Page
When I navigate to Home Page
Then Home Page should be displayed
Scenario Outline: Exchange rate form works
# todo
Examples:
| date | currency | expected output |
| 2017/02/03 | USD | 1.00 USD = 4.0014 PLN |
| 2017/02/03 | ISK | 1.00 ISK = 0.035263 PLN |
| 2017/02/05 | USD | No data for this day |
features/environment.py¶
%%writefile nbp/features/environment.py
from webtest import TestApp

from frontend import app


def before_scenario(context, scenario):
    context.client = TestApp(app)
features/steps/exchange_rate_steps.py¶
%%writefile nbp/features/steps/exchange_rate_steps.py
# todo
SOLUTION¶
frontend.py
%%writefile nbp/frontend.py
from datetime import datetime

from flask import Flask, request

import backend

app = Flask(__name__)

HOME_PAGE_TEMPLATE = """
<h1>Home Page</h1>
<p>{msg}</p>
<form method="POST">
    <p>Date: <input type="text" name="date" /></p>
    <p>Currency
        <select name="currency">
            <option value="USD">dolar amerykanski USD</option>
            <option value="THB">bat (Tajlandia) THB</option>
            <option value="ISK">korona islandzka ISK</option>
        </select>
    </p>
    <p><input type="submit" value="Get exchange rate!"></p>
</form>
"""
EXCHANGE_RATE_TEMPLATE = \
    '1.00 {currency} = {rate} PLN'
NO_DATA_MSG = 'No data for this day.'
DATE_FORMAT = '%Y/%m/%d'


@app.route('/', methods=['GET', 'POST'])
def home():
    if request.method == 'POST':
        date_as_str = request.form['date']
        date = datetime.strptime(date_as_str, DATE_FORMAT).date()
        currency = request.form['currency']
        try:
            rate = backend.get_exchange_rate(
                date=date, currency=currency)
        except backend.NoData:
            msg = NO_DATA_MSG
        else:
            msg = EXCHANGE_RATE_TEMPLATE.format(
                currency=currency,
                rate=rate)
    else:
        msg = ''
    return HOME_PAGE_TEMPLATE.format(msg=msg)


if __name__ == "__main__":
    app.run()
test_frontend.py
%%writefile nbp/test_frontend.py
import datetime
from unittest import mock

import frontend
import backend


@mock.patch('frontend.request', spec=object())
class TestHomePage:
    def test_home_page_should_have_exchange_rate_form(self, request_mock):
        request_mock.method = 'GET'
        html = frontend.home()
        assert '<form' in html
        assert 'name="date"' in html

    @mock.patch('frontend.backend')
    def test_home_page_should_work_for_working_days(self, backend_mock, request_mock):
        request_mock.method = 'POST'
        request_mock.form = dict(date='2019/11/08', currency='USD')
        backend_mock.get_exchange_rate.return_value = 4.321
        html = frontend.home()
        assert '1.00 USD = 4.321 PLN' in html
        backend_mock.get_exchange_rate.assert_called_once_with(
            date=datetime.date(2019, 11, 8),
            currency='USD')

    @mock.patch('frontend.backend.get_exchange_rate')
    def test_home_page_should_work_for_non_working_days(self, get_exchange_rate_mock, request_mock):
        request_mock.method = 'POST'
        request_mock.form = dict(date='2019/11/09', currency='USD')
        get_exchange_rate_mock.side_effect = backend.NoData
        html = frontend.home()
        assert 'No data for this day.' in html
        get_exchange_rate_mock.assert_called_once_with(
            date=datetime.date(2019, 11, 9),
            currency='USD')
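The argument order in TestHomePage follows how mock.patch stacks. The patch closest to the function supplies the first mock argument. A class-level patch is applied last, so its mock comes last. That is why the signatures read (self, backend_mock, request_mock). The sketch below checks the same rule on os.path functions, with a hypothetical class name, so it runs without the NBP code.

```python
import os.path
from unittest import mock


@mock.patch("os.path.exists")      # class-level patch: applied to every test* method, mock comes last
class TestArgumentOrder:
    @mock.patch("os.path.isdir")   # method-level patch: closest to the function, mock comes first
    def test_order(self, isdir_mock, exists_mock):
        isdir_mock.return_value = True
        exists_mock.return_value = False
        return os.path.isdir("/no/such/path"), os.path.exists("/no/such/path")


result = TestArgumentOrder().test_order()
```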
features/exchange_rate.feature
%%writefile nbp/features/exchange_rate.feature
Feature: Exchange rate form
In order to become rich
As a future currency speculator
I'd like to know currency exchange rates
Scenario: Navigation to Home Page
When I navigate to Home Page
Then Home Page should be displayed
Scenario Outline: Exchange rate form works
Given I navigate to Home Page
And I enter <date> as date
And I enter <currency> as currency
When I send the form
Then <expected output> should be displayed
Examples:
| date | currency | expected output |
| 2017/02/03 | USD | 1.00 USD = 4.0014 PLN |
| 2017/02/03 | ISK | 1.00 ISK = 0.035263 PLN |
| 2017/02/05 | USD | No data for this day |
features/steps/exchange_rate_steps.py
%%writefile nbp/features/steps/exchange_rate_steps.py
from behave import *


# @step matches the phrase after Given, When or Then, so it replaces both of these:
# @given('I navigate to Home Page')
# @when('I navigate to Home Page')
@step('I navigate to Home Page')
def step_impl(ctx):
    ctx.resp = ctx.client.get('/')


@given('I enter {value} as {field}')
def step_impl(ctx, value, field):
    ctx.resp.form[field] = value


@when('I send the form')
def step_impl(ctx):
    ctx.resp = ctx.resp.form.submit()


@then('{text} should be displayed')
def step_impl(ctx, text):
    assert text in ctx.resp
Redis¶
This module is designed to provide you with a comprehensive understanding of Redis, from installation to practical usage in Python. By the end of this module, you'll be equipped to leverage Redis in your applications effectively. This module is structured to be completed within half a day, leading up to the start of the twitter.py exercise.
Redis (Remote Dictionary Server) is an open-source, in-memory data structure store that functions as a versatile database, cache, and message broker. Renowned for its speed and efficiency, Redis is widely used in modern applications to enhance performance, manage real-time data, and facilitate complex data operations. Its ability to handle various data structures with atomic operations makes it a powerful tool for developers seeking both simplicity and performance.
Key Features of Redis¶
Redis boasts a rich set of features that cater to a wide range of use cases. Below are its core functionalities:
In-Memory Data Storage¶
- Speed: By storing data in RAM, Redis provides ultra-fast read and write operations, often completing them in sub-millisecond times.
- Latency: Minimal latency makes Redis ideal for applications requiring real-time data processing and quick response times.
Persistence Options¶
While Redis is primarily an in-memory store, it offers multiple persistence mechanisms to ensure data durability:
- Snapshotting (RDB): Periodically saves the dataset to disk, allowing for point-in-time recovery.
- Append-Only File (AOF): Logs every write operation received by the server, enabling a more granular recovery of data.
- Hybrid Approach: Combines both RDB and AOF for optimal performance and data safety.
Rich Data Structures¶
Redis supports a variety of data structures, each optimized for specific operations:
- Strings: Simple key-value pairs, suitable for caching and simple data storage.
- Hashes: Maps between string fields and string values, ideal for representing objects.
- Lists: Ordered collections of strings, useful for queues and stacks.
- Sets: Unordered collections of unique strings, perfect for membership testing and uniqueness constraints.
- Sorted Sets (ZSets): Similar to sets but with an associated score for each member, enabling ranking and ordering.
- Bitmaps, HyperLogLogs, and Geospatial Indexes: Specialized data structures for advanced use cases like counting unique items, tracking bits, and handling location-based data.
Atomic Operations¶
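Every individual Redis command is atomic. The server executes commands one at a time, so a single command completes entirely before another client's command can start. This guarantees consistency for single operations in concurrent environments. A sequence of several commands, however, is only atomic when it is wrapped in a transaction (MULTI/EXEC) or a Lua script.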
High Availability and Scalability¶
- Replication: Redis supports primary-replica (historically called master-slave) replication, allowing data to be copied to multiple servers for redundancy and for spreading read load.
- Sentinel: Provides high availability by monitoring Redis instances and performing automatic failover in case of failures.
- Redis Cluster: Enables horizontal scaling by partitioning data across multiple Redis nodes, ensuring seamless scalability and fault tolerance.
Pub/Sub Messaging¶
Redis includes a Publish/Subscribe messaging paradigm, allowing messages to be broadcast to multiple subscribers. This feature is useful for real-time messaging, event notification systems, and inter-service communication.
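Redis Pub/Sub is fire-and-forget. PUBLISH delivers a message only to the clients subscribed at that moment and returns how many received it. Nothing is stored for subscribers who join later. The in-process model below illustrates those semantics without a server. ToyBroker and its methods are hypothetical names, not redis-py's API.

```python
from collections import defaultdict


class ToyBroker:
    """In-process stand-in for Redis Pub/Sub (illustration only)."""

    def __init__(self):
        self.channels = defaultdict(list)  # channel name -> list of subscriber inboxes

    def subscribe(self, channel):
        inbox = []
        self.channels[channel].append(inbox)
        return inbox

    def publish(self, channel, message):
        # Deliver to current subscribers only and report how many got the message;
        # the message is not kept anywhere for future subscribers.
        subscribers = self.channels.get(channel, [])
        for inbox in subscribers:
            inbox.append(message)
        return len(subscribers)


broker = ToyBroker()
missed = broker.publish("news", "sent before anyone listened")
alice = broker.subscribe("news")
bob = broker.subscribe("news")
delivered = broker.publish("news", "hello")
```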
Transactions¶
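Redis transactions (MULTI ... EXEC) let you group multiple commands. The commands are queued and then executed sequentially as a single isolated unit, so no other client's command runs in the middle. If EXEC is never called, none of the queued commands are executed. Note that Redis does not roll back: if one command fails at execution time, the remaining commands are still executed.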
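In a Redis transaction, commands are queued after MULTI and only run, in order, at EXEC. Redis does not roll back, so a command that fails at execution time (for example INCR on a non-numeric string) reports an error while the other commands still apply. The toy model below, a hypothetical ToyTransaction class over a dict rather than a real Redis client, shows that queue-then-execute behaviour.

```python
class ToyTransaction:
    """Toy model of MULTI/EXEC: queue commands, then run them all in order."""

    def __init__(self, store):
        self.store = store
        self.queue = []

    def set(self, key, value):
        self.queue.append(("SET", key, value))
        return "QUEUED"  # nothing has run yet

    def incr(self, key):
        self.queue.append(("INCR", key, None))
        return "QUEUED"

    def execute(self):
        results = []
        for command, key, value in self.queue:
            try:
                if command == "SET":
                    self.store[key] = str(value)
                    results.append(True)
                else:  # INCR
                    new_value = int(self.store.get(key, "0")) + 1
                    self.store[key] = str(new_value)
                    results.append(new_value)
            except ValueError as error:
                # The failing command reports an error; no rollback of the others
                results.append(error)
        self.queue.clear()
        return results


store = {}
tx = ToyTransaction(store)
tx.set("name", "John")
tx.incr("name")            # will fail at execution time: "John" is not an integer
tx.incr("visits")
before_exec = dict(store)  # still empty: commands are only queued
results = tx.execute()
```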
Lightweight and Portable¶
Redis is lightweight, easy to install, and runs on various platforms, including Linux, macOS, and Windows (via WSL or Docker). Its portability makes it accessible for diverse development environments.
Common Use Cases for Redis¶
Redis's flexibility and performance make it suitable for a wide array of applications:
Caching:
- Purpose: Reduce latency and offload database queries by storing frequently accessed data in Redis.
- Example: Caching API responses, session data, or user profiles.
Real-Time Analytics:
- Purpose: Handle high-speed data ingestion and provide immediate insights.
- Example: Tracking website metrics, monitoring application performance, or analyzing user behavior in real-time.
Session Management:
- Purpose: Store user session information efficiently with quick access times.
- Example: Managing user authentication tokens, shopping cart data, or user preferences.
Message Queues:
- Purpose: Implement reliable and scalable message queuing systems for asynchronous processing.
- Example: Task scheduling, background job processing, or inter-service communication in microservices architectures.
Leaderboard and Counting:
- Purpose: Maintain dynamic leaderboards and counters with real-time updates.
- Example: Gaming leaderboards, vote counts, or view counters.
Geospatial Applications:
- Purpose: Store and query location-based data efficiently.
- Example: Finding nearby users, tracking delivery vehicles, or mapping services.
Full-Page Caching:
- Purpose: Cache entire web pages to accelerate web application performance.
- Example: E-commerce websites serving cached product pages to reduce server load.
Rate Limiting:
- Purpose: Control the rate of requests to APIs or services to prevent abuse and ensure fair usage.
- Example: Limiting the number of login attempts or API calls per user within a specific timeframe.
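A common way to implement rate limiting with Redis is a fixed-window counter. You INCR a key that identifies the user and the current time window, give the key an EXPIRE equal to the window length, and reject requests once the count exceeds the limit. The sketch below models that logic in plain Python so it runs without a server. The dict standing in for Redis, the key format, and the class name are assumptions for illustration.

```python
import time


class FixedWindowRateLimiter:
    """Allow at most `limit` calls per `window` seconds for each user."""

    def __init__(self, limit, window, clock=time.time):
        self.limit = limit
        self.window = window
        self.clock = clock
        # Stand-in for Redis keys; with Redis, EXPIRE would delete old windows,
        # here old counters simply stay in the dict.
        self.counters = {}

    def allow(self, user):
        window_index = int(self.clock() // self.window)
        key = f"rate:{user}:{window_index}"     # one counter per user per window
        count = self.counters.get(key, 0) + 1   # corresponds to INCR key
        self.counters[key] = count
        return count <= self.limit


now = [1000.0]  # controllable fake clock, so the example is deterministic
limiter = FixedWindowRateLimiter(limit=3, window=60, clock=lambda: now[0])
first_window = [limiter.allow("alice") for _ in range(4)]
now[0] += 60    # move into the next 60-second window
next_window = limiter.allow("alice")
```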
Fundamentals¶
Redis Installation¶
Before diving into Redis, it's essential to install and set it up correctly on your system. Below are the steps to install Redis, particularly focusing on Windows users utilizing the Windows Subsystem for Linux (WSL).
Installation Steps:¶
Set Up WSL (Windows Subsystem for Linux):
If you're on Windows, it's recommended to use WSL for a seamless Redis experience. Follow the official guide to install WSL.
Install Redis:
Open your WSL terminal and execute the following commands:
sudo apt update
sudo apt install redis-server
Start Redis Server:
After installation, start the Redis server using:
sudo service redis-server start
Verify Installation:
Test if Redis is running correctly by pinging the server:
redis-cli ping
You should receive a response:
PONG
Install redis-py for Python Integration:
To interact with Redis using Python, install the redis library:
pip install redis
Basic Python Connection:¶
Here's a simple Python script to connect to your local Redis server and perform basic operations:
import redis
# Create a connection to the localhost Redis server instance
# By default, Redis runs on port 6379
redis_db = redis.Redis(host="localhost", port=6379, db=0)
# Retrieve all keys (should be empty initially)
print(redis_db.keys()) # Output: []
Explanation:
- redis.Redis: Creates a client for the Redis server. The network connection itself is opened lazily, when the first command is sent.
- host & port: Specify the server address and port.
- db: Specifies the database number (default is 0).
Using the Redis CLI¶
The Redis Command-Line Interface (CLI) is a powerful tool for interacting with your Redis server directly. Below are common commands and their usages.
Common Redis CLI Commands:¶
127.0.0.1:6379> FLUSHALL
OK
127.0.0.1:6379> KEYS
(error) ERR wrong number of arguments for 'keys' command
127.0.0.1:6379> KEYS *
(empty array)
127.0.0.1:6379> SET db_url "asgdfhdsaghfgghdsa"
OK
127.0.0.1:6379> KEYS *
1) "db_url"
127.0.0.1:6379> GET db_url
"asgdfhdsaghfgghdsa"
127.0.0.1:6379> SET counter 0
OK
127.0.0.1:6379> INCR counter
(integer) 1
127.0.0.1:6379> GET counter
"1"
Explanation of Commands:
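- FLUSHALL: Removes all keys from all databases.
- KEYS pattern: Lists the keys in the current database that match the pattern; KEYS * lists all of them. The pattern is required, which is why the bare KEYS above returned an error.
- SET key value: Sets the value of a key.
- GET key: Retrieves the value of a key.
- INCR key: Increments the integer value of a key by one.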
Understanding Race Conditions¶
Race conditions occur when multiple clients read and modify the same data concurrently and their steps interleave, leading to unexpected results. Each Redis command is atomic, but a sequence of commands such as GET followed by SET is not. Understanding potential race conditions is therefore crucial for maintaining data integrity.
Example Without Race Condition:¶
# Initialize the counter
SET counter 0
# Client A retrieves the counter
GET counter # Returns 0
# Client A increments the counter
SET counter 1
# Client B retrieves the updated counter
GET counter # Returns 1
# Client B increments the counter
SET counter 2
Outcome: The counter increments correctly from 0 to 2 without any issues.
Example With Race Condition:¶
# Initialize the counter
SET counter 0
# Client A retrieves the counter
GET counter # Returns 0
# Client B retrieves the counter
GET counter # Also returns 0
# Both clients increment the counter based on the retrieved value
SET counter 1 # Client A
SET counter 1 # Client B
Outcome: Instead of the counter being 2, it remains at 1 due to both clients reading the same initial value and setting it independently.
Mitigating Race Conditions:¶
To prevent race conditions, use Redis's atomic operations such as INCR:
# Initialize the counter
SET counter 0
# Client A increments the counter
INCR counter # Counter becomes 1
# Client B increments the counter
INCR counter # Counter becomes 2
Outcome: The counter correctly increments to 2, ensuring data integrity.
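The two timelines above can be replayed deterministically in plain Python. Each client is a generator that pauses at the point where another client's command could run. A dict stands in for the Redis server, so this illustrates the interleaving rather than talking to Redis. The function names are hypothetical.

```python
def get_then_set_client(store):
    """Non-atomic increment: GET, then SET value + 1 (two separate commands)."""
    value = store["counter"]        # GET counter
    yield                           # another client's command may run here
    store["counter"] = value + 1    # SET counter <value + 1>


def incr_client(store):
    """Atomic increment: a single INCR reads and writes in one step."""
    store["counter"] += 1           # INCR counter
    yield


def run_two_clients(client):
    store = {"counter": 0}          # SET counter 0
    a, b = client(store), client(store)
    next(a)                         # client A runs up to its pause
    next(b)                         # client B runs up to its pause
    for c in (a, b):
        next(c, None)               # let A, then B, run to completion
    return store["counter"]


lost_update = run_two_clients(get_then_set_client)
atomic = run_two_clients(incr_client)
```

With GET/SET both clients read 0, so one update is lost and the counter ends at 1. With INCR each client's read and write happen in one step, so the counter reaches 2.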
Interacting with Redis in Python¶
Leveraging Redis within Python applications allows for efficient data storage and retrieval. Below are examples of common operations using the redis-py library.
Setting Up the Connection:¶
import redis
# Create a connection to the localhost Redis server instance
# By default, Redis runs on port 6379
redis_db = redis.Redis(host="localhost", port=6379, db=0, decode_responses=True)
# Flush all existing data
print(redis_db.flushall()) # Output: True
# Retrieve all keys (should be empty)
print(redis_db.keys()) # Output: []
# Set a key-value pair
print(redis_db.set('db_url', 'asgdfhdsaghfgghdsa')) # Output: True
Performing Basic Operations:¶
# Retrieve all keys
print(redis_db.keys()) # Output: ['db_url']
# Get the value of a key
print(redis_db.get('db_url')) # Output: 'asgdfhdsaghfgghdsa'
# Initialize a counter
redis_db.set('counter', 0)
# Increment the counter
print(redis_db.incr('counter')) # Output: 1
# Get the updated counter value
print(redis_db.get('counter')) # Output: '1'
Explanation:
- flushall(): Clears all data from Redis.
- keys(): Retrieves all keys in the current database.
- set(key, value): Sets the value for a specified key.
- get(key): Retrieves the value of a specified key.
- incr(key): Atomically increments the integer value of a key by one.
Advanced Operations:¶
To handle more complex scenarios and ensure atomicity, consider using transactions or Redis pipelines.
# Using a pipeline for atomic operations
pipeline = redis_db.pipeline()
pipeline.set('user:1000', 'John Doe')
pipeline.incr('counter')
pipeline.get('user:1000')
pipeline.get('counter')
results = pipeline.execute()
print(results)
# Output: [True, 2, 'John Doe', '2']
print(redis_db.keys())
# Output, e.g.: ['user:1000', 'counter', 'db_url'] (KEYS returns keys in no guaranteed order)
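Benefits of Pipelines:
- Performance: Reduces the number of network round trips by sending the commands to the server in one batch.
- Atomicity: By default, redis-py's pipeline() wraps the batched commands in MULTI/EXEC (transaction=True), so they run as a single transaction; pipeline(transaction=False) only batches them.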
Redis Data Types¶
Understanding Redis data types is fundamental to leveraging Redis effectively in your applications. Redis supports a variety of data types, each optimized for specific use cases and operations. This section explores these data types, their associated commands, and practical examples of how to use them both in the Redis CLI and within Python using the redis-py library.
Overview of Redis Data Types¶
This section covers the following Redis data types (Redis also offers others, such as Streams and geospatial indexes):
- Strings
- Lists
- Hashes
- Sets
- Sorted Sets
- Bitmaps
- HyperLogLogs
Additionally, Redis provides commands for managing the key space and for other operations, such as setting expirations and flushing databases.
For a comprehensive overview of Redis data types, refer to the Redis Data Types Introduction.
Querying the Key Space¶
Before diving into specific data types, it's essential to understand how to interact with the Redis key space. The key space is where all keys and their associated values are stored.
Common Key Space Commands¶
- EXISTS key: Checks if a key exists.
- SET key value: Sets the value of a key.
- DEL key: Deletes a key.
- TYPE key: Returns the data type of the value stored at the key.
Examples in Redis CLI¶
127.0.0.1:6379> SET greeting "Hello, Redis!"
OK
127.0.0.1:6379> EXISTS greeting
(integer) 1
127.0.0.1:6379> TYPE greeting
string
127.0.0.1:6379> DEL greeting
(integer) 1
127.0.0.1:6379> EXISTS greeting
(integer) 0
Examples in Python¶
import redis
# Connect to Redis
redis_db = redis.Redis(host="localhost", port=6379, db=0, decode_responses=True)
# Set a key
redis_db.set('greeting', 'Hello, Redis!')
# Check if the key exists
print(redis_db.exists('greeting')) # Output: 1
# Get the type of the key
print(redis_db.type('greeting')) # Output: 'string'
# Delete the key
redis_db.delete('greeting')
# Verify deletion
print(redis_db.exists('greeting')) # Output: 0
Strings¶
Description:
Strings are the simplest Redis data type, representing a sequence of bytes. They can store text, numbers, or binary data.
Common Commands:
- SET key value: Sets the value of a key.
- GET key: Retrieves the value of a key.
- MSET key1 value1 key2 value2 ...: Sets multiple keys to multiple values.
- MGET key1 key2 ...: Retrieves the values of multiple keys.
- INCR key: Increments the integer value of a key by one.
- INCRBY key increment: Increments the integer value of a key by a specified amount.
Examples in Redis CLI:
127.0.0.1:6379> SET user:1000 "John Doe"
OK
127.0.0.1:6379> GET user:1000
"John Doe"
127.0.0.1:6379> MSET user:1001 "Jane Smith" user:1002 "Alice Johnson"
OK
127.0.0.1:6379> MGET user:1000 user:1001 user:1002
1) "John Doe"
2) "Jane Smith"
3) "Alice Johnson"
127.0.0.1:6379> SET counter 10
OK
127.0.0.1:6379> INCR counter
(integer) 11
127.0.0.1:6379> INCRBY counter 5
(integer) 16
Examples in Python:
# Set a single key
redis_db.set('user:1000', 'John Doe')
# Get the value of a key
print(redis_db.get('user:1000')) # Output: 'John Doe'
# Set multiple keys
redis_db.mset({'user:1001': 'Jane Smith', 'user:1002': 'Alice Johnson'})
# Get multiple keys
print(redis_db.mget('user:1000', 'user:1001', 'user:1002'))
# Output: ['John Doe', 'Jane Smith', 'Alice Johnson']
# Increment a counter
redis_db.set('counter', 10)
redis_db.incr('counter')
print(redis_db.get('counter')) # Output: '11'
# Increment by a specific value
redis_db.incrby('counter', 5)
print(redis_db.get('counter')) # Output: '16'
Lists¶
Description: Lists are ordered collections of strings, implemented as linked lists. They are ideal for implementing queues, stacks, and other ordered data structures.
Common Commands:
- RPUSH key value: Appends a value to the end of a list.
- LPUSH key value: Prepends a value to the beginning of a list.
- RPOP key: Removes and returns the last element of the list.
- LPOP key: Removes and returns the first element of the list.
- LRANGE key start stop: Retrieves a range of elements from the list.
- LLEN key: Returns the length of the list.
- LTRIM key start stop: Trims the list to the specified range.
- BRPOP key timeout: Removes and returns the last element of the list, blocking until an element is available or timeout is reached.
Examples in Redis CLI:
127.0.0.1:6379> RPUSH tasks "task1"
(integer) 1
127.0.0.1:6379> RPUSH tasks "task2" "task3"
(integer) 3
127.0.0.1:6379> LPUSH tasks "task0"
(integer) 4
127.0.0.1:6379> LRANGE tasks 0 -1
1) "task0"
2) "task1"
3) "task2"
4) "task3"
127.0.0.1:6379> LLEN tasks
(integer) 4
127.0.0.1:6379> RPOP tasks
"task3"
127.0.0.1:6379> LRANGE tasks 0 -1
1) "task0"
2) "task1"
3) "task2"
Examples in Python:
# Push elements to the right of the list
redis_db.rpush('tasks', 'task1')
redis_db.rpush('tasks', 'task2', 'task3')
# Push an element to the left of the list
redis_db.lpush('tasks', 'task0')
# Retrieve all elements in the list
print(redis_db.lrange('tasks', 0, -1))
# Output: ['task0', 'task1', 'task2', 'task3']
# Get the length of the list
print(redis_db.llen('tasks')) # Output: 4
# Pop an element from the right
print(redis_db.rpop('tasks')) # Output: 'task3'
# Retrieve the updated list
print(redis_db.lrange('tasks', 0, -1))
# Output: ['task0', 'task1', 'task2']
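Unlike Python slicing, the stop index of LRANGE is inclusive, and negative indexes count from the end (-1 is the last element). LTRIM uses the same index rules. The helper below is a hypothetical pure-Python model of those rules, handy for predicting what lrange() will return. It is not part of redis-py.

```python
def lrange_model(items: list[str], start: int, stop: int) -> list[str]:
    """Model of Redis LRANGE index rules on a Python list (hypothetical helper)."""
    n = len(items)
    # Negative indexes count from the end of the list
    if start < 0:
        start = max(n + start, 0)
    if stop < 0:
        stop = n + stop
    # stop is inclusive in Redis, so convert it to an exclusive Python bound
    return items[start:stop + 1] if start <= stop else []


tasks = ["task0", "task1", "task2", "task3"]
print(lrange_model(tasks, 0, -1))   # ['task0', 'task1', 'task2', 'task3']
print(lrange_model(tasks, 0, 1))    # ['task0', 'task1']
print(lrange_model(tasks, -2, -1))  # ['task2', 'task3']
print(lrange_model(tasks, 5, 10))   # []
```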
Hashes¶
Description:
Hashes are maps between string fields and string values, allowing you to store multiple key-value pairs under a single Redis key. They are ideal for representing objects with multiple attributes.
Common Commands:
- HSET key field value: Sets the value of a field in a hash (since Redis 4.0 it also accepts multiple field-value pairs).
- HGET key field: Retrieves the value of a field in a hash.
- HGETALL key: Retrieves all fields and values in a hash.
- HMSET key field1 value1 field2 value2 ...: Sets multiple fields in a hash (deprecated since Redis 4.0 in favor of HSET with multiple pairs).
- HDEL key field: Deletes a field from a hash.
- HINCRBY key field increment: Increments the integer value of a field by a specified amount.
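Examples in Redis CLI:
Note: user:1000 still holds the string set in the Strings examples. Delete it first with DEL user:1000; otherwise HSET fails with a WRONGTYPE error.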
127.0.0.1:6379> HSET user:1000 username "jan"
(integer) 1
127.0.0.1:6379> HSET user:1000 password "1234"
(integer) 1
127.0.0.1:6379> HSET user:1000 birthyear "1977"
(integer) 1
127.0.0.1:6379> HSET user:1000 verified "1"
(integer) 1
127.0.0.1:6379> HGETALL user:1000
1) "username"
2) "jan"
3) "password"
4) "1234"
5) "birthyear"
6) "1977"
7) "verified"
8) "1"
127.0.0.1:6379> HINCRBY user:1000 verified 1
(integer) 2
127.0.0.1:6379> HGET user:1000 verified
"2"
Examples in Python:
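# user:1000 holds a string from the Strings section; delete it to avoid a WRONGTYPE error
redis_db.delete('user:1000')
# Set fields in a hash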
redis_db.hset('user:1000', 'username', 'jan')
redis_db.hset('user:1000', 'password', '1234')
redis_db.hset('user:1000', 'birthyear', '1977')
redis_db.hset('user:1000', 'verified', '1')
# Retrieve all fields and values
print(redis_db.hgetall('user:1000'))
# Output: {'username': 'jan', 'password': '1234', 'birthyear': '1977', 'verified': '1'}
# Increment a field
redis_db.hincrby('user:1000', 'verified', 1)
print(redis_db.hget('user:1000', 'verified')) # Output: '2'
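redis-py encodes int values when writing a hash. However, hgetall() always returns strings (with decode_responses=True), so reading an object back needs an explicit type conversion. Here is a minimal sketch using a hypothetical Profile dataclass. to_hash() shows the string form the fields end up in, and from_hash() restores the types. Plain dicts stand in for Redis, so it runs without a server.

```python
from dataclasses import asdict, dataclass


@dataclass
class Profile:
    """Hypothetical object stored as a Redis hash."""
    username: str
    birthyear: int
    verified: int


def to_hash(profile: Profile) -> dict[str, str]:
    # What the fields look like once stored: every value becomes a string
    return {field: str(value) for field, value in asdict(profile).items()}


def from_hash(data: dict[str, str]) -> Profile:
    # hgetall() returns str -> str, so convert the numeric fields back
    return Profile(
        username=data["username"],
        birthyear=int(data["birthyear"]),
        verified=int(data["verified"]),
    )


stored = to_hash(Profile(username="jan", birthyear=1977, verified=1))
print(stored)             # {'username': 'jan', 'birthyear': '1977', 'verified': '1'}
print(from_hash(stored))  # Profile(username='jan', birthyear=1977, verified=1)
```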
Sets¶
Description:
Sets are unordered collections of unique strings. They are useful for membership testing, ensuring uniqueness, and performing set operations like unions and intersections.
Common Commands:
- SADD key member1 member2 ...: Adds one or more members to a set.
- SREM key member: Removes a member from a set.
- SISMEMBER key member: Checks if a member exists in a set.
- SMEMBERS key: Retrieves all members of a set.
- SUNION key1 key2 ...: Returns the union of multiple sets.
- SUNIONSTORE destination key1 key2 ...: Stores the union of multiple sets in a destination set.
- SCARD key: Returns the number of members in a set.
Examples in Redis CLI:
127.0.0.1:6379> SADD tags "python" "redis" "database"
(integer) 3
127.0.0.1:6379> SADD tags "cache" "memory"
(integer) 2
127.0.0.1:6379> SMEMBERS tags
1) "python"
2) "redis"
3) "database"
4) "cache"
5) "memory"
127.0.0.1:6379> SISMEMBER tags "redis"
(integer) 1
127.0.0.1:6379> SREM tags "cache"
(integer) 1
127.0.0.1:6379> SMEMBERS tags
1) "python"
2) "redis"
3) "database"
4) "memory"
127.0.0.1:6379> SCARD tags
(integer) 4
127.0.0.1:6379> SADD tags2 "redis" "scaling"
(integer) 2
127.0.0.1:6379> SUNION tags tags2
1) "python"
2) "redis"
3) "database"
4) "memory"
5) "scaling"
127.0.0.1:6379> SUNIONSTORE tags_union tags tags2
(integer) 5
127.0.0.1:6379> SMEMBERS tags_union
1) "python"
2) "redis"
3) "database"
4) "memory"
5) "scaling"
Examples in Python:
# Add members to a set
redis_db.sadd('tags', 'python', 'redis', 'database')
redis_db.sadd('tags', 'cache', 'memory')
# Retrieve all members
print(redis_db.smembers('tags'))
# Output: {'python', 'redis', 'database', 'cache', 'memory'}
# Check membership
print(redis_db.sismember('tags', 'redis')) # Output: 1 (older redis-py versions return True)
print(redis_db.sismember('tags', 'flask')) # Output: 0 (older redis-py versions return False)
# Remove a member
redis_db.srem('tags', 'cache')
print(redis_db.smembers('tags'))
# Output: {'python', 'redis', 'database', 'memory'}
# Get the number of members
print(redis_db.scard('tags')) # Output: 4
# Perform set union
redis_db.sadd('tags2', 'redis', 'scaling')
union = redis_db.sunion('tags', 'tags2')
print(union)
# Output: {'python', 'redis', 'database', 'memory', 'scaling'}
# Store the union in a new set
redis_db.sunionstore('tags_union', 'tags', 'tags2')
print(redis_db.smembers('tags_union'))
# Output: {'python', 'redis', 'database', 'memory', 'scaling'}
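Redis set commands map directly onto Python's built-in set operations. That makes plain Python sets a quick way to check what a command such as sunion() should return. The sketch below reuses the member names from the example. The intersection (the SINTER command) is included only because the description mentions intersections. Results are sorted just to get a stable print order.

```python
tags = {"python", "redis", "database", "cache", "memory"}
tags.discard("cache")         # like SREM tags cache
tags2 = {"redis", "scaling"}

print("redis" in tags)        # like SISMEMBER -> True
print(len(tags))              # like SCARD -> 4
print(sorted(tags | tags2))   # like SUNION
# ['database', 'memory', 'python', 'redis', 'scaling']
print(sorted(tags & tags2))   # like SINTER -> ['redis']
```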
Sorted Sets¶
Description:
Sorted sets are similar to sets but with an associated score for each member. Members are ordered based on their scores, allowing for efficient range queries and ranking.
Common Commands:
- ZADD key score1 member1 score2 member2 ...: Adds one or more members with associated scores to a sorted set.
- ZRANGE key start stop [WITHSCORES]: Retrieves a range of members by index, ordered by ascending score.
- ZREVRANGE key start stop [WITHSCORES]: Retrieves a range of members by index in reverse order (deprecated since Redis 6.2 in favor of ZRANGE with REV).
- ZRANGEBYSCORE key min max [WITHSCORES]: Retrieves members within a score range (deprecated since Redis 6.2 in favor of ZRANGE with BYSCORE).
- ZRANK key member: Returns the 0-based rank of a member, ordered by ascending score.
- ZREM key member: Removes a member from a sorted set.
- ZREMRANGEBYSCORE key min max: Removes all members within a score range.
Examples in Redis CLI:
127.0.0.1:6379> ZADD hackers 1912 "Alan Turing"
(integer) 1
127.0.0.1:6379> ZADD hackers 1940 "Alan Kay"
(integer) 1
127.0.0.1:6379> ZADD hackers 1957 "Sophie Wilson"
(integer) 1
127.0.0.1:6379> ZRANGE hackers 0 -1 WITHSCORES
1) "Alan Turing"
2) "1912"
3) "Alan Kay"
4) "1940"
5) "Sophie Wilson"
6) "1957"
127.0.0.1:6379> ZRANK hackers "Alan Kay"
(integer) 1
127.0.0.1:6379> ZREM hackers "Alan Turing"
(integer) 1
127.0.0.1:6379> ZRANGE hackers 0 -1 WITHSCORES
1) "Alan Kay"
2) "1940"
3) "Sophie Wilson"
4) "1957"
Examples in Python:
# Add members with scores to a sorted set
redis_db.zadd('hackers', {'Alan Turing': 1912, 'Alan Kay': 1940, 'Sophie Wilson': 1957})
# Retrieve all members with scores
print(redis_db.zrange('hackers', 0, -1, withscores=True))
# Output: [('Alan Turing', 1912.0), ('Alan Kay', 1940.0), ('Sophie Wilson', 1957.0)]
# Get the rank of a member
print(redis_db.zrank('hackers', 'Alan Kay')) # Output: 1
# Remove a member
redis_db.zrem('hackers', 'Alan Turing')
print(redis_db.zrange('hackers', 0, -1, withscores=True))
# Output: [('Alan Kay', 1940.0), ('Sophie Wilson', 1957.0)]
Bitmaps¶
Description:
Bitmaps allow you to manipulate individual bits within a string value. They are useful for scenarios like tracking user activities, feature flags, and storing compact data.
Common Commands:
- SETBIT key offset value: Sets or clears the bit at the specified offset and returns the previous value of that bit.
- GETBIT key offset: Retrieves the bit value at the specified offset.
- BITCOUNT key [start end]: Counts the number of set bits (1s) in a string.
Examples in Redis CLI:
127.0.0.1:6379> SETBIT user:1000:active 7 1
(integer) 0
127.0.0.1:6379> GETBIT user:1000:active 7
(integer) 1
127.0.0.1:6379> SETBIT user:1000:active 0 1
(integer) 0
127.0.0.1:6379> GETBIT user:1000:active 0
(integer) 1
127.0.0.1:6379> BITCOUNT user:1000:active
(integer) 2
Examples in Python:
# Set bits at specific offsets
redis_db.setbit('user:1000:active', 7, 1)
redis_db.setbit('user:1000:active', 0, 1)
# Get bits at specific offsets
print(redis_db.getbit('user:1000:active', 7)) # Output: 1
print(redis_db.getbit('user:1000:active', 0)) # Output: 1
# Count the number of set bits
print(redis_db.bitcount('user:1000:active')) # Output: 2
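Bitmap offsets count from the most significant bit of the first byte. After the two setbit() calls above, the string stored at user:1000:active is therefore the single byte 0x81 (bits 0 and 7 set). The bytearray-based functions below are a hypothetical model of that layout, not redis-py code.

```python
def setbit(buf: bytearray, offset: int, value: int) -> int:
    """Model of SETBIT: bit 0 is the most significant bit of byte 0. Returns the old bit."""
    byte_index, bit_index = divmod(offset, 8)
    if byte_index >= len(buf):
        # Like Redis, grow the string with zero bytes as needed
        buf.extend(b"\x00" * (byte_index + 1 - len(buf)))
    mask = 0x80 >> bit_index
    old = 1 if buf[byte_index] & mask else 0
    if value:
        buf[byte_index] |= mask
    else:
        buf[byte_index] &= ~mask & 0xFF
    return old


def getbit(buf: bytearray, offset: int) -> int:
    byte_index, bit_index = divmod(offset, 8)
    if byte_index >= len(buf):
        return 0  # bits beyond the end of the string read as 0
    return 1 if buf[byte_index] & (0x80 >> bit_index) else 0


def bitcount(buf: bytearray) -> int:
    return sum(bin(byte).count("1") for byte in buf)


active = bytearray()
print(setbit(active, 7, 1))  # 0 (previous value)
print(setbit(active, 0, 1))  # 0
print(bytes(active))         # b'\x81'
print(getbit(active, 7), getbit(active, 0), bitcount(active))  # 1 1 2
```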
HyperLogLogs¶
Description:
HyperLogLogs are probabilistic data structures used to estimate the cardinality (number of unique elements) of a dataset with minimal memory usage. In Redis, a HyperLogLog uses at most 12 KB per key and has a standard error of 0.81%. They are ideal for large-scale unique counts, such as website visitors.
Common Commands:
- PFADD key element1 element2 ...: Adds elements to the HyperLogLog; returns 1 if any internal register was altered, 0 otherwise.
- PFCOUNT key: Returns the approximate cardinality of the HyperLogLog.
Examples in Redis CLI:
127.0.0.1:6379> PFADD unique_visitors user1
(integer) 1
127.0.0.1:6379> PFADD unique_visitors user2 user3 user4
(integer) 1
127.0.0.1:6379> PFCOUNT unique_visitors
(integer) 4
127.0.0.1:6379> PFADD unique_visitors user2 user5
(integer) 1
127.0.0.1:6379> PFCOUNT unique_visitors
(integer) 5
Examples in Python:
# Add elements to HyperLogLog
redis_db.pfadd('unique_visitors', 'user1')
redis_db.pfadd('unique_visitors', 'user2', 'user3', 'user4')
# Get the approximate count
print(redis_db.pfcount('unique_visitors')) # Output: 4
# Add more elements, including duplicates
redis_db.pfadd('unique_visitors', 'user2', 'user5')
# Get the updated count
print(redis_db.pfcount('unique_visitors')) # Output: 5
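To see why PFCOUNT is approximate yet cheap, here is a toy HyperLogLog in pure Python. It works in four steps:
- Each item is hashed.
- The first p bits of the hash pick one of 2^p registers.
- Each register keeps the largest "leading zeros plus one" seen in the remaining bits.
- A harmonic mean over the registers gives the estimate, with linear counting used for small cardinalities.

This is a simplified textbook version for illustration (SHA-1 as the hash, p=10). It is not the Redis implementation, which uses 16384 registers and additional corrections.

```python
import hashlib
import math


class ToyHyperLogLog:
    """Simplified HyperLogLog for illustration only; not the Redis implementation."""

    def __init__(self, p: int = 10) -> None:
        self.p = p
        self.m = 1 << p  # number of registers
        self.registers = [0] * self.m

    def add(self, *items: str) -> None:
        width = 64 - self.p  # bits left after choosing a register
        for item in items:
            # 64-bit hash: the first 8 bytes of SHA-1
            h = int.from_bytes(hashlib.sha1(item.encode()).digest()[:8], "big")
            index = h >> width                    # top p bits choose a register
            rest = h & ((1 << width) - 1)         # remaining bits
            rank = width - rest.bit_length() + 1  # leading zeros in rest, plus one
            self.registers[index] = max(self.registers[index], rank)

    def count(self) -> int:
        alpha = 0.7213 / (1 + 1.079 / self.m)
        raw = alpha * self.m * self.m / sum(2.0 ** -r for r in self.registers)
        zeros = self.registers.count(0)
        if raw <= 2.5 * self.m and zeros:
            # Small-range correction: linear counting over empty registers
            return round(self.m * math.log(self.m / zeros))
        return round(raw)


hll = ToyHyperLogLog()
hll.add("user1")
hll.add("user2", "user3", "user4")
first = hll.count()
hll.add("user2", "user5")   # user2 is a duplicate and changes nothing
print(first, hll.count())   # estimates of the 4, then 5, distinct users

crowd = ToyHyperLogLog()
crowd.add(*(f"visitor{i}" for i in range(10_000)))
print(crowd.count())        # an estimate near 10,000, typically within a few percent
```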
Other Essential Commands¶
In addition to data type-specific commands, Redis provides several other commands for managing keys and the database.
Expiration and Time-to-Live (TTL)¶
- EXPIRE key seconds: Sets a timeout on a key.
- TTL key: Retrieves the remaining time-to-live of a key in seconds (-2 if the key does not exist, -1 if it exists but has no timeout).
Examples in Redis CLI:
127.0.0.1:6379> SET session:abc123 "user data"
OK
127.0.0.1:6379> EXPIRE session:abc123 3600
(integer) 1
127.0.0.1:6379> TTL session:abc123
(integer) 3600
127.0.0.1:6379> TTL non_existent_key
(integer) -2
Examples in Python:
# Set a key with expiration
redis_db.set('session:abc123', 'user data')
redis_db.expire('session:abc123', 3600)
# Get the TTL of a key
print(redis_db.ttl('session:abc123')) # Output: 3600
# Get TTL of a non-existent key
print(redis_db.ttl('non_existent_key')) # Output: -2
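TTL distinguishes three cases: seconds left for a key with a timeout, -1 for a key without one, and -2 for a missing (or already expired) key. ExpiringStore is a hypothetical in-memory model of that behavior with an injectable clock, so the cases can be checked without a server or real waiting. It is not redis-py.

```python
import time
from collections.abc import Callable


class ExpiringStore:
    """Hypothetical in-memory model of SET / EXPIRE / TTL semantics (not redis-py)."""

    def __init__(self, clock: Callable[[], float] = time.monotonic) -> None:
        self.clock = clock
        self.values: dict[str, str] = {}
        self.deadlines: dict[str, float] = {}

    def _drop_if_expired(self, key: str) -> None:
        deadline = self.deadlines.get(key)
        if deadline is not None and self.clock() >= deadline:
            # An expired key behaves as if it had been deleted
            self.values.pop(key, None)
            self.deadlines.pop(key, None)

    def set(self, key: str, value: str) -> None:
        self.values[key] = value
        self.deadlines.pop(key, None)  # a plain SET discards any previous timeout

    def expire(self, key: str, seconds: int) -> bool:
        self._drop_if_expired(key)
        if key not in self.values:
            return False  # nothing to expire
        self.deadlines[key] = self.clock() + seconds
        return True

    def ttl(self, key: str) -> int:
        self._drop_if_expired(key)
        if key not in self.values:
            return -2  # key does not exist
        if key not in self.deadlines:
            return -1  # key exists but has no timeout
        return round(self.deadlines[key] - self.clock())


now = [1000.0]  # fake clock, advanced by hand
store = ExpiringStore(clock=lambda: now[0])
store.set("session:abc123", "user data")
print(store.ttl("session:abc123"))    # -1
store.expire("session:abc123", 3600)
print(store.ttl("session:abc123"))    # 3600
now[0] += 3599
print(store.ttl("session:abc123"))    # 1
now[0] += 2
print(store.ttl("session:abc123"))    # -2
print(store.ttl("non_existent_key"))  # -2
```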
Flushing the Database¶
- FLUSHALL: Removes all keys from all databases.
- FLUSHDB: Removes all keys from the current database.
Examples in Redis CLI:
127.0.0.1:6379> FLUSHDB
OK
127.0.0.1:6379> FLUSHALL
OK
Examples in Python:
# Flush the current database
redis_db.flushdb()
# Flush all databases
redis_db.flushall()
Exercise: 🏠 Build a Redis-Powered Twitter Clone¶
In this exercise, you will implement basic operations of a Twitter-like microblogging platform using Redis as a backend. The operations include:
- Registering a new user: Provide a username and password and store them in Redis.
- Logging in: Validate user credentials and generate an authentication token.
- Following a user: Add a user to the following/follower lists.
- Creating a new post: Insert a new message and propagate it to the followers’ newsfeeds and the global timeline.
- Getting your newsfeed: Retrieve the 10 most recent posts from users you follow.
- Getting the global timeline: Retrieve recent posts from all users, 10 at a time, possibly with pagination.
- Logging out: Invalidate the user’s authentication token.
Data Layout
- next_user_id and next_post_id: Integer counters for assigning unique IDs to new users and posts.
- user:<user_id>: A hash storing username, password, and auth (token).
- users: A hash mapping username to user_id.
- followers:<user_id>: A sorted set of all followers of <user_id>, sorted by the time they started following.
- following:<user_id>: A sorted set of all users that <user_id> follows, sorted by the time.
- post:<post_id>: A hash storing user_id, username, time, and body.
- timeline: A global capped list (trimmed to 1000 latest posts).
- posts:<user_id>: A list of post IDs for the user’s personal newsfeed.
- auths: A hash mapping tokens to user_ids.
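The capped timeline relies on LPUSH followed by LTRIM. The stop index of LTRIM is inclusive, so keeping the newest N items means LTRIM timeline 0 N-1; a call such as ltrim('timeline', 0, 1000) keeps 1001 items. The helper below is a hypothetical pure-Python model of that pair of commands and needs no Redis.

```python
def lpush_capped(items: list[int], value: int, cap: int) -> list[int]:
    """Model of LPUSH key value followed by LTRIM key 0 cap-1 (hypothetical helper)."""
    items = [value] + items  # LPUSH prepends
    return items[0:cap]      # LTRIM 0 cap-1 keeps indexes 0..cap-1 inclusive


timeline: list[int] = []
for post_id in range(1, 1201):
    timeline = lpush_capped(timeline, post_id, cap=1000)

print(len(timeline))              # 1000
print(timeline[0], timeline[-1])  # 1200 201
```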
Auth Logic
- Register: Check if a user exists by username. If not, increment next_user_id, create user:<id>, and map username -> user_id.
- Login: Given username and password, verify credentials. If correct, generate a token, store it in user:<id> and auths.
- is_logged_in (require_login in twitter.py): Verify that the provided token matches the one stored for the user; otherwise raise NotAuthenticated.
- Logout: Remove the token from user:<id> and auths.
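The Auth Logic above can be modelled in memory to see the data flow without a server. One caveat: the exercise's generate_random_token() uses the random module, which the Python documentation warns against for security purposes. The secrets module is the standard choice for authentication tokens.

The sketch below is hypothetical. Dicts stand in for the users, user:<id> and auths keys, and ValueError, LookupError and PermissionError stand in for the exercise's own exceptions. It is not the required twitter.py API.

```python
import secrets


class AuthModel:
    """Hypothetical in-memory model of the Auth Logic; dicts stand in for Redis keys."""

    def __init__(self) -> None:
        self.next_user_id = 0                              # like next_user_id
        self.users: dict[str, int] = {}                    # like users: username -> user_id
        self.user_hashes: dict[int, dict[str, str]] = {}   # like user:<id>
        self.auths: dict[str, int] = {}                    # like auths: token -> user_id

    def register(self, username: str, password: str) -> int:
        if username in self.users:
            raise ValueError(f"user {username!r} already exists")
        self.next_user_id += 1
        user_id = self.next_user_id
        # Plain-text password, as in the exercise; real systems store a hash
        self.user_hashes[user_id] = {"username": username, "password": password}
        self.users[username] = user_id
        return user_id

    def login(self, username: str, password: str) -> tuple[str, int]:
        user_id = self.users.get(username)
        if user_id is None:
            raise LookupError("no such user")
        stored = self.user_hashes[user_id]["password"]
        # Compare bytes so that non-ASCII passwords also work with compare_digest
        if not secrets.compare_digest(stored.encode(), password.encode()):
            raise PermissionError("invalid password")
        token = secrets.token_hex(16)  # 32 hex characters from a secure source
        self.user_hashes[user_id]["auth"] = token
        self.auths[token] = user_id
        return token, user_id

    def is_logged_in(self, user_id: int, token: str) -> bool:
        stored = self.user_hashes.get(user_id, {}).get("auth")
        return stored is not None and secrets.compare_digest(stored.encode(), token.encode())

    def logout(self, user_id: int, token: str) -> None:
        if not self.is_logged_in(user_id, token):
            raise PermissionError("not authenticated")
        del self.user_hashes[user_id]["auth"]
        del self.auths[token]


auth = AuthModel()
uid = auth.register("jan", "1234")
token, _ = auth.login("jan", "1234")
print(auth.is_logged_in(uid, token))  # True
auth.logout(uid, token)
print(auth.is_logged_in(uid, token))  # False
```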
Other Logic
- Create Post: Ensure the user is authenticated. Create a new post:<id> hash. Push the post ID onto the author’s and all their followers’ posts:<user_id> lists, as well as onto the global timeline.
- Follow: Ensure authentication. Update followers:<followed_id> and following:<follower_id> with timestamps.
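Create Post as described is fan-out on write: a post ID is copied into follower feeds only at the moment it is created. Following someone therefore does not backfill their older posts, which is why, in the session below, ala's newsfeed only shows jan's posts written after she followed him. The following is a hypothetical in-memory sketch of Follow, Create Post and the newsfeed read, with dicts and lists standing in for the Redis keys.

```python
from collections import defaultdict


class FeedModel:
    """Hypothetical model of follow / create_post / get_newsfeed with fan-out on write."""

    def __init__(self) -> None:
        self.followers: dict[str, set[str]] = defaultdict(set)  # like followers:<id>
        self.feeds: dict[str, list[int]] = defaultdict(list)    # like posts:<id>
        self.timeline: list[int] = []                           # like timeline
        self.next_post_id = 0                                   # like next_post_id

    def follow(self, follower: str, followed: str) -> None:
        self.followers[followed].add(follower)

    def create_post(self, author: str) -> int:
        self.next_post_id += 1
        post_id = self.next_post_id
        # Copy the ID to the author and to every *current* follower (like LPUSH)
        for user in {author, *self.followers[author]}:
            self.feeds[user].insert(0, post_id)
        self.timeline.insert(0, post_id)
        return post_id

    def newsfeed(self, user: str, limit: int = 10) -> list[int]:
        return self.feeds[user][:limit]  # like LRANGE posts:<id> 0 limit-1


model = FeedModel()
model.create_post("jan")       # post 1, written before ala follows jan
model.follow("ala", "jan")
model.create_post("jan")       # post 2
print(model.newsfeed("ala"))   # [2]
print(model.newsfeed("jan"))   # [2, 1]
print(model.timeline)          # [2, 1]
```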
CLI Usage Example¶
Commands such as register username password
, login username password
, post message
, timeline
, newsfeed
, follow username
, logout
, and bye
will interact with Redis through the provided CLI.
(Cmd) help register
Creates a new user: REGISTER username password
(Cmd) register jan 1234
(Cmd) post asdf!
:-( Log in first
(Cmd) login jan 123456
:-( Invalid password
(Cmd) login jan 1234
(Cmd) post Hi, I'm Jan!
(Cmd) post And this is my second post!
(Cmd) timeline
Fri Jan 15 12:11:48 2021 by jan : And this is my second post!
Fri Jan 15 12:11:45 2021 by jan : Hi, I'm Jan!
(Cmd) register ala 1234
(Cmd) login ala 1234
(Cmd) post Hi, and I'm Ala!
(Cmd) timeline
Fri Jan 15 12:12:26 2021 by ala : Hi, and I'm Ala!
Fri Jan 15 12:11:48 2021 by jan : And this is my second post!
Fri Jan 15 12:11:45 2021 by jan : Hi, I'm Jan!
(Cmd) newsfeed
Fri Jan 15 12:12:26 2021 by ala : Hi, and I'm Ala!
(Cmd) follow jan
(Cmd) logout
(Cmd) register kamil 12345
(Cmd) login kamil 12345
(Cmd) post Kamil here
(Cmd) login ala 1234
(Cmd) newsfeed
Fri Jan 15 12:12:26 2021 by ala : Hi, and I'm Ala!
(Cmd) login jan 1234
(Cmd) post Oh, I see, Ala, that you've started following me.
(Cmd) login ala 1234
(Cmd) newsfeed
Fri Jan 15 12:13:56 2021 by jan : Oh, I see, Ala, that you've started following me.
Fri Jan 15 12:12:26 2021 by ala : Hi, and I'm Ala!
(Cmd) timeline
Fri Jan 15 12:13:56 2021 by jan : Oh, I see, Ala, that you've started following me.
Fri Jan 15 12:13:17 2021 by kamil : Kamil here
Fri Jan 15 12:12:26 2021 by ala : Hi, and I'm Ala!
Fri Jan 15 12:11:48 2021 by jan : And this is my second post!
Fri Jan 15 12:11:45 2021 by jan : Hi, I'm Jan!
(Cmd) bye
Bye!
Initial Code¶
You are given the following starting point:
- twitter_cli.py provides a command-line interface using Python’s cmd module.
- twitter.py provides data structures and some helper functions but leaves the core logic (login, require_login, follow, create_post, get_newsfeed, get_timeline, logout) unimplemented.
Your task: Fill in the NotImplementedError sections in twitter.py.
twitter_cli.py (Initial Code)¶
# twitter_cli.py
import cmd
from datetime import datetime
import redis
from typing import Optional
import twitter
class TwitterCLI(cmd.Cmd):
    def __init__(self, db: redis.Redis):
        self.db = db
        self.token: Optional[twitter.Token] = None
        self.user_id: Optional[twitter.UserId] = None
        super().__init__()
    def do_register(self, arg: str):
        """Creates a new user: REGISTER username password """
        try:
            username, password = arg.split()
        except ValueError:
            print("Invalid usage. Use HELP REGISTER")
        else:
            twitter.register(self.db, username, password)
    def do_login(self, arg: str):
        """Logs in as a user: LOGIN username password """
        try:
            username, password = arg.split()
        except ValueError:
            print("Invalid usage. Use HELP LOGIN")
        else:
            try:
                self.token, self.user_id = twitter.login(self.db, username, password)
            except twitter.NoSuchUser:
                print(":-( No such user")
            except twitter.InvalidPassword:
                print(":-( Invalid password")
    def do_follow(self, username: str):
        """Starts following a user: FOLLOW obama """
        if self.token is None:
            print(":-( Log in first")
        else:
            try:
                twitter.follow(self.db, self.user_id, self.token, username)
            except twitter.NoSuchUser:
                print(f":-( Looks like there is no {username}")
            except twitter.NotAuthenticated:
                print(":-( Auth error -- login again")
    def do_post(self, body: str):
        """Posts a new update as the logged in user: POST This is your message! """
        if self.token is None:
            print(":-( Log in first")
        else:
            try:
                twitter.create_post(self.db, self.user_id, self.token, body)
            except twitter.NotAuthenticated:
                print(":-( Auth error - login again")
    def do_newsfeed(self, body: str):
        """Lists 10 most recent posts on your timeline: NEWSFEED """
        if self.token is None:
            print(":-( Log in first")
        else:
            try:
                posts = twitter.get_newsfeed(self.db, self.user_id, self.token)
            except twitter.NotAuthenticated:
                print(":-( Auth error - login again")
            else:
                for post in posts:
                    date = datetime.fromtimestamp(post.time)
                    print(f"{date.ctime():25} by {post.username:15}: {post.body}")
    def do_timeline(self, body: str):
        """Lists 10 most recent posts on the global timeline: TIMELINE.
        You can specify the page, so that you can paginate to next 10 posts: TIMELINE 1 """
        try:
            page = int(body.strip())
        except ValueError:
            page = 0
        if self.token is None:
            print(":-( Log in first")
        else:
            posts = twitter.get_timeline(self.db, offset=page*10)
            for post in posts:
                date = datetime.fromtimestamp(post.time)
                print(f"{date.ctime():25} by {post.username:15}: {post.body}")
    def do_logout(self, body: str):
        """Logs you out. """
        if self.token is None:
            print(":-( First log in ")
        else:
            try:
                twitter.logout(self.db, self.user_id, self.token)
            except twitter.NotAuthenticated:
                print(":-( Auth error - login again, then logout")
            self.token = None
            self.user_id = None
    def do_bye(self, arg):
        """Exits """
        print("Bye!")
        return True
if __name__ == "__main__":
    redis_db = redis.Redis(host="localhost", port=6379, db=0, decode_responses=True)
    cli = TwitterCLI(redis_db)
    cli.cmdloop()
twitter.py (Initial Code - You must complete the logic)¶
# twitter.py - Initial Code
from dataclasses import dataclass, asdict
from datetime import datetime
import random
from typing import NewType, Tuple, List
import redis
TimeStamp = NewType('TimeStamp', int)
Token = NewType('Token', str)
UserId = NewType('UserId', int)
PostId = NewType('PostId', int)
class NoSuchUser(Exception):
    pass
class InvalidPassword(Exception):
    pass
class NotAuthenticated(Exception):
    pass
@dataclass
class User:
    username: str
    password: str
@dataclass
class Post:
    user_id: UserId
    username: str
    time: TimeStamp
    body: str
def current_timestamp() -> int:
    return int(datetime.now().timestamp())
def generate_random_token() -> Token:
    return str(random.randint(0, 10000000000000))
def register(db: redis.Redis, username: str, password: str) -> None:
    user_id = db.incr('next_user_id')
    user = User(username=username, password=password)
    db.hset(f'user:{user_id}', mapping=asdict(user))
    db.hset('users', username, user_id)
def login(db: redis.Redis, username: str, password: str) -> Tuple[Token, UserId]:
    """ May raise NoSuchUser or InvalidPassword. """
    raise NotImplementedError
def require_login(db: redis.Redis, user_id: UserId, token: Token) -> None:
    """ May raise NotAuthenticated. """
    raise NotImplementedError
def follow(db: redis.Redis, follower_id: UserId, token: Token, followed_username: str) -> None:
    """ May raise NoSuchUser or NotAuthenticated """
    require_login(db, follower_id, token)
    raise NotImplementedError
def create_post(db: redis.Redis, user_id: UserId, token: Token, body: str) -> None:
    """ May raise NotAuthenticated """
    require_login(db, user_id, token)
    raise NotImplementedError
def get_newsfeed(db: redis.Redis, user_id: UserId, token: Token) -> List[Post]:
    require_login(db, user_id, token)
    raise NotImplementedError
def get_timeline(db: redis.Redis, offset: int = 0, limit: int = 10) -> List[Post]:
    raise NotImplementedError
def logout(db: redis.Redis, user_id: UserId, token: Token) -> None:
    """ May raise NotAuthenticated. """
    raise NotImplementedError
Solution¶
from dataclasses import dataclass, asdict
from datetime import datetime
import random
from typing import NewType, Tuple, List
import redis
TimeStamp = NewType('TimeStamp', int)
Token = NewType('Token', str)
UserId = NewType('UserId', int)
PostId = NewType('PostId', int)
class NoSuchUser(Exception):
    pass
class InvalidPassword(Exception):
    pass
class NotAuthenticated(Exception):
    pass
@dataclass
class User:
    username: str
    password: str
@dataclass
class Post:
    user_id: UserId
    username: str
    time: TimeStamp
    body: str
def current_timestamp() -> int:
    return int(datetime.now().timestamp())
def generate_random_token() -> Token:
    return str(random.randint(0, 10000000000000))
def register(db: redis.Redis, username: str, password: str) -> None:
    user_id = db.incr('next_user_id')
    user = User(username=username, password=password)
    db.hset(f'user:{user_id}', mapping=asdict(user))
    db.hset('users', username, user_id)
def login(db: redis.Redis, username: str, password: str) -> Tuple[Token, UserId]:
    user_id = db.hget('users', username)
    if user_id is None:
        raise NoSuchUser
    correct_password = db.hget(f'user:{user_id}', 'password')
    if password != correct_password:
        raise InvalidPassword
    token = generate_random_token()
    # Store the token on the user and in the token -> user_id index
    db.hset(f'user:{user_id}', 'auth', token)
    db.hset('auths', token, user_id)
    # hget() returns a string, so convert it to match the declared UserId (int) type
    return token, UserId(int(user_id))
def require_login(db: redis.Redis, user_id: UserId, token: Token) -> None:
    stored_user_id = db.hget('auths', token)
    if stored_user_id is None:
        raise NotAuthenticated
    real_auth = db.hget(f'user:{user_id}', 'auth')
    if real_auth != token:
        raise NotAuthenticated
def follow(db: redis.Redis, follower_id: UserId, token: Token, followed_username: str) -> None:
    require_login(db, follower_id, token)
    followed_user_id = db.hget('users', followed_username)
    if followed_user_id is None:
        raise NoSuchUser
    now = current_timestamp()
    db.zadd(f'followers:{followed_user_id}', {follower_id: now})
    db.zadd(f'following:{follower_id}', {followed_user_id: now})
def create_post(db: redis.Redis, user_id: UserId, token: Token, body: str) -> None:
    require_login(db, user_id, token)
    post_id = db.incr('next_post_id')
    username = db.hget(f'user:{user_id}', 'username')
    post = Post(
        user_id=int(user_id),
        username=username,
        time=current_timestamp(),
        body=body,
    )
    db.hset(f'post:{post_id}', mapping=asdict(post))
    # Push to all followers
    followers = db.zrange(f'followers:{user_id}', 0, -1)
    for fid in followers:
        db.lpush(f'posts:{fid}', post_id)
    # Also push to author's own feed
    db.lpush(f'posts:{user_id}', post_id)
    # Push to global timeline and keep the 1000 newest posts
    # (LTRIM's stop index is inclusive, so 0..999 is 1000 elements)
    db.lpush('timeline', post_id)
    db.ltrim('timeline', 0, 999)
def _get_posts_by_ids(db: redis.Redis, post_ids: List[PostId]) -> List[Post]:
    posts = []
    for pid in post_ids:
        data = db.hgetall(f'post:{pid}')
        posts.append(Post(
            user_id=int(data['user_id']),
            username=data['username'],
            time=int(data['time']),
            body=data['body'],
        ))
    return posts
def get_newsfeed(db: redis.Redis, user_id: UserId, token: Token) -> List[Post]:
    require_login(db, user_id, token)
    post_ids = db.lrange(f'posts:{user_id}', 0, 9)  # 10 posts
    return _get_posts_by_ids(db, post_ids)
def get_timeline(db: redis.Redis, offset: int = 0, limit: int = 10) -> List[Post]:
    post_ids = db.lrange('timeline', offset, offset+limit-1)
    return _get_posts_by_ids(db, post_ids)
def logout(db: redis.Redis, user_id: UserId, token: Token) -> None:
    require_login(db, user_id, token)
    db.hdel(f'user:{user_id}', 'auth')
    db.hdel('auths', token)
Publish - Subscribe¶
Redis provides a built-in Publish/Subscribe (Pub/Sub) messaging paradigm. This allows you to broadcast messages to multiple subscribers listening on different channels. Subscribers can listen for new messages as they arrive, enabling real-time notifications, chat systems, streaming logs, and more.
Python: get_message¶
Below is an example of using Redis Pub/Sub with Python and the redis-py library. We first create a Redis connection and subscribe to a channel. By calling get_message(), we can fetch the next available message in a non-blocking manner.
import redis
r = redis.Redis(host="localhost", port=6379, db=0, decode_responses=True)
p = r.pubsub()
p.subscribe('mychannel')
# Initially, we get a subscribe message:
msg = p.get_message()
print(msg)
{'type': 'subscribe', 'pattern': None, 'channel': 'mychannel', 'data': 1}
print(p.get_message()) # Non-blocking attempt to get another message
# None (since no new messages are available)
None
# Publish a message from another client:
r2 = redis.Redis(host="localhost", port=6379, db=0, decode_responses=True)
r2.publish('mychannel', 'my message content')  # returns how many subscribed clients received the message, not only p
4
# Now, retrieving the message:
msg = p.get_message()
print(msg)
{'type': 'message', 'pattern': None, 'channel': 'mychannel', 'data': 'my message content'}
Python: listen¶
If you want to listen continuously for incoming messages, use the listen() method. It blocks and yields each new message as soon as it arrives.
r3 = redis.Redis(host="localhost", port=6379, db=0, decode_responses=True)
p3 = r3.pubsub()
p3.subscribe('mychannel')
for msg in p3.listen():
    print(msg)
{'type': 'subscribe', 'pattern': None, 'channel': 'mychannel', 'data': 1}
{'type': 'message', 'pattern': None, 'channel': 'mychannel', 'data': 'message'}
{'type': 'message', 'pattern': None, 'channel': 'mychannel', 'data': 'message 1'}
{'type': 'message', 'pattern': None, 'channel': 'mychannel', 'data': 'message 2'}
{'type': 'message', 'pattern': None, 'channel': 'mychannel', 'data': 'message 2'}
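Note: listen() is blocking. Its loop only ends once the client is no longer subscribed to any channel, so to stop it you either break out of the for loop (for example, when a particular message arrives) or interrupt the execution, which raises KeyboardInterrupt.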
Python: Registered Callbacks¶
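When subscribing, you can also register a callback (handler) for a channel by passing it as a keyword argument named after the channel. Messages published to that channel are passed to the callback instead of being yielded by listen(); only the subscribe confirmation reaches the loop body.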
def callback(msg):
    print('inside callback', msg)
r4 = redis.Redis(host="localhost", port=6379, db=0, decode_responses=True)
p4 = r4.pubsub()
# Subscribe to 'myanother' and register callback as its handler
p4.subscribe(myanother=callback)
for msg in p4.listen():
    print('inside listen', msg)
inside listen {'type': 'subscribe', 'pattern': None, 'channel': 'myanother', 'data': 1}
inside callback {'type': 'message', 'pattern': None, 'channel': 'myanother', 'data': 'message 1'}
inside callback {'type': 'message', 'pattern': None, 'channel': 'myanother', 'data': 'message 2'}
Callbacks run synchronously inside the listen() loop, so a long-running callback delays the processing of every subsequent message. You can add a sleep() call to the callback to observe this blocking behavior.
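Python: run_in_thread¶
To avoid blocking the main thread, use run_in_thread() to handle messages in a separate thread. Your main code continues immediately while incoming messages are dispatched to their callbacks in the background. Every subscribed channel must have a registered handler for this to work, and the handlers still run one after another inside that single worker thread.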
from time import sleep

def callback(msg):
    print('inside callback', msg)
    sleep(5)  # Simulate a long-running task

r5 = redis.Redis(host="localhost", port=6379, db=0, decode_responses=True)
p5 = r5.pubsub()
p5.subscribe(mychannel=callback)
# run_in_thread() starts a background thread that polls for messages and calls their handlers
thread = p5.run_in_thread()
print('This line executes immediately because run_in_thread is non-blocking.')
# When finished, stop the worker thread:
# thread.stop()
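To see the threading pattern behind run_in_thread() without a Redis server, here is a standard-library sketch. ToyWorkerThread is a hypothetical stand-in, not the redis-py class: a background thread pulls messages from a queue.Queue and calls the handler, while the main thread keeps going. The slow callback also shows that messages are still handled one at a time inside the worker; only the main thread is freed.

```python
import queue
import threading
import time
from typing import Callable, List


class ToyWorkerThread(threading.Thread):
    """Hypothetical stand-in for the thread returned by run_in_thread()."""

    def __init__(self, inbox: "queue.Queue[dict]", handler: Callable[[dict], None]) -> None:
        super().__init__(daemon=True)
        self.inbox = inbox
        self.handler = handler
        self.stop_requested = threading.Event()

    def run(self) -> None:
        while not self.stop_requested.is_set():
            try:
                msg = self.inbox.get(timeout=0.01)  # short timeout so stop() is noticed
            except queue.Empty:
                continue
            try:
                self.handler(msg)  # handlers run sequentially in this thread
            finally:
                self.inbox.task_done()

    def stop(self) -> None:
        self.stop_requested.set()


handled: List[str] = []


def slow_callback(msg: dict) -> None:
    time.sleep(0.05)  # simulate a long-running task
    handled.append(msg["data"])


inbox: "queue.Queue[dict]" = queue.Queue()
worker = ToyWorkerThread(inbox, slow_callback)
worker.start()

for i in range(3):
    inbox.put({"type": "message", "channel": "mychannel", "data": f"message {i}"})
print("main thread keeps running while the worker handles messages")

inbox.join()  # wait until every queued message has been handled
worker.stop()
worker.join()
```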
Working with FastAPI¶
In this chapter, we'll explore the core concepts of building APIs with FastAPI, a modern, high-performance Python web framework based on standard Python type hints. By the end of this chapter, you'll understand how to define API routes, handle parameters and request bodies, validate data, and return responses in a clean, efficient, and Pythonic way.
First Steps¶
Objective: Set up a basic FastAPI application and understand the fundamental building blocks.
- Installation:
pip install fastapi uvicorn
- Hello World Example:
from fastapi import FastAPI

app = FastAPI()

@app.get("/")
def read_root():
    return {"message": "Hello World"}
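- Running the Server:
uvicorn fast:app --reload
Here fast is the module name (the file fast.py) and app is the FastAPI object inside it; if your file is named main.py, use main:app instead. The --reload flag restarts the server whenever the code changes.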
By visiting http://127.0.0.1:8000/ in your browser, you'll see your API in action. FastAPI also generates interactive documentation at http://127.0.0.1:8000/docs.
Path Parameters¶
Objective: Learn how to define dynamic URL paths.
- Defining Path Parameters:
@app.get("/items/{item_id}")
def read_item(item_id: int):
    return {"item_id": item_id}
- Type Hints for Validation:
Declaring item_id: int makes FastAPI convert the path segment to an integer and validate it; a value such as /items/abc is rejected with a 422 error.
Query Parameters¶
Objective: Retrieve optional parameters from the query string.
- Defining Query Parameters:
@app.get("/items")
def read_items(skip: int = 0, limit: int = 10):
    return {"skip": skip, "limit": limit}
- Optional Parameters and Defaults:
If not provided, skip defaults to 0 and limit defaults to 10.
Request Body¶
Objective: Send JSON data in the request body and parse it into Python objects.
- Pydantic Models:
from typing import Optional
from pydantic import BaseModel

class Item(BaseModel):
    name: str
    description: Optional[str] = None

@app.post("/items/")
def create_item(item: Item):
    return {"item": item}
- Automatic Data Parsing and Validation:
FastAPI validates the request data and converts it into an Item instance.
Query Parameters and String Validations¶
Objective: Add validations to query parameters, such as length constraints.
- String Validations:
from typing import Optional
from fastapi import Query

@app.get("/users")
def read_users(q: Optional[str] = Query(None, min_length=3, max_length=50)):
    return {"q": q}
Using Query for metadata and validation: you can set default values, add regex validations (the pattern parameter), and provide descriptions.
Path Parameters and Numeric Validations¶
Objective: Validate numeric path parameters with constraints like minimum and maximum values.
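- Integer Validations:
from fastapi import Path

@app.get("/items/{item_id}")
def read_items(item_id: int = Path(..., ge=1, le=1000)):
    return {"item_id": item_id}
ge=1 (greater than or equal) and le=1000 (less than or equal) restrict item_id to the range 1 to 1000 inclusive.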
Query Parameter Models¶
Requires FastAPI 0.115 or later.
Objective: Use pydantic models to handle complex query parameters.
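- Complex Query Parameters:
from typing import Annotated
from fastapi import Query
from pydantic import BaseModel

class UserFilter(BaseModel):
    name: str
    age: int

@app.get("/search")
def search_users(filters: Annotated[UserFilter, Query()]):
    return {"filters": filters}
Without Query(), FastAPI would treat a Pydantic model parameter as a JSON request body rather than as query parameters.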
- Dependency Injection for Parsing:
On older FastAPI versions, you can declare the model as a dependency instead (filters: UserFilter = Depends()), and FastAPI reads the model's fields from the query string.
Body - Multiple Parameters¶
Objective: Combine query parameters, path parameters, and request bodies in a single endpoint.
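- Mixing Parameters:
@app.post("/orders/{order_id}")
def update_order(order_id: int, item: Item, confirm: bool = True):
    return {"order_id": order_id, "item": item, "confirm": confirm}
FastAPI infers where each parameter comes from: order_id is declared in the path, item is a Pydantic model and is read from the JSON body, and confirm is a simple type, so it is read from the query string.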
Body - Fields¶
Objective: Add metadata and validations to individual fields in a request body model.
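- Using Field:
from pydantic import BaseModel, Field

class Product(BaseModel):
    name: str = Field(..., examples=["Laptop"])
    price: float = Field(..., gt=0, examples=[999.99])
In Pydantic v2, examples takes a list; passing example= as an extra keyword is deprecated.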
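Field constraints are enforced by Pydantic itself, so they can be tested without a server. This sketch assumes Pydantic v2 (model_json_schema and the greater_than error type are v2 names): a zero price is rejected, and the constraint and examples also appear in the JSON schema that the docs are built from.

```python
from pydantic import BaseModel, Field, ValidationError


class Product(BaseModel):
    name: str = Field(..., examples=["Laptop"])
    price: float = Field(..., gt=0, examples=[999.99])


laptop = Product(name="Laptop", price=999.99)

try:
    Product(name="Freebie", price=0)  # gt=0 rejects zero and negative prices
except ValidationError as exc:
    errors = exc.errors()
else:
    errors = []

price_schema = Product.model_json_schema()["properties"]["price"]
print(errors[0]["loc"], price_schema)
```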
Body - Nested Models¶
Objective: Create nested data structures and validate them using Pydantic models.
Nested Models:
class Manufacturer(BaseModel):
    name: str
    country: str

class Product(BaseModel):
    name: str
    price: float
    manufacturer: Manufacturer

@app.post("/products/")
def create_product(product: Product):
    return {"product": product}
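Validation recurses into nested models: the inner dict becomes a Manufacturer instance, and an error inside it is reported with the full path to the offending field. A sketch assuming Pydantic v2 (model_validate); the payload values are made up for illustration:

```python
from pydantic import BaseModel, ValidationError


class Manufacturer(BaseModel):
    name: str
    country: str


class Product(BaseModel):
    name: str
    price: float
    manufacturer: Manufacturer


payload = {
    "name": "Laptop",
    "price": 999.99,
    "manufacturer": {"name": "Acme", "country": "PL"},
}
product = Product.model_validate(payload)  # nested dict -> Manufacturer instance

try:
    # Same payload, but the nested manufacturer lacks its country
    Product.model_validate({**payload, "manufacturer": {"name": "Acme"}})
except ValidationError as exc:
    errors = exc.errors()
else:
    errors = []
```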
Declare Request Example Data¶
Objective: Provide examples for request bodies to display in the interactive docs.
- Examples in Pydantic Models:
from typing import Optional
from pydantic import BaseModel

class Item(BaseModel):
    name: str
    description: Optional[str] = None

    model_config = {
        "json_schema_extra": {
            "examples": [
                {
                    "name": "Example Item",
                    "description": "A sample item.",
                }
            ]
        }
    }
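FastAPI derives the documentation from the model's JSON schema, so inspecting that schema shows the example that /docs will display. A sketch assuming Pydantic v2 (model_json_schema):

```python
from typing import Optional

from pydantic import BaseModel


class Item(BaseModel):
    name: str
    description: Optional[str] = None

    model_config = {
        "json_schema_extra": {
            "examples": [{"name": "Example Item", "description": "A sample item."}]
        }
    }


schema = Item.model_json_schema()
# json_schema_extra is merged into the top level of the generated schema
print(schema["examples"])
```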
Header Parameters¶
Objective: Extract headers from the request.
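- Header Parameter:
from typing import Optional
from fastapi import Header

@app.get("/headers")
def read_headers(user_agent: Optional[str] = Header(None)):
    return {"User-Agent": user_agent}
FastAPI converts the underscore in user_agent to a hyphen, and HTTP header names are case-insensitive, so this parameter receives the User-Agent header.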
Response Model - Return Type¶
Objective: Define the response model to ensure consistent response formats.
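- Response Model:
from pydantic import BaseModel

class User(BaseModel):
    id: int
    name: str

@app.get("/users/{user_id}", response_model=User)
def get_user(user_id: int):
    return User(id=user_id, name="John Doe")
FastAPI validates the returned data against User and includes only the fields declared in User in the response.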
Response Status Code¶
Objective: Set HTTP status codes for your responses.
- Specifying Status Codes:
from fastapi import status

@app.post("/items/", status_code=status.HTTP_201_CREATED)
def create_item(item: Item):
    return item
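Unit Tests¶
FastAPI's TestClient (which requires the httpx package) calls the application directly, without starting a server, so tests like the one below can be run with pytest.
from fastapi import FastAPI
from fastapi.testclient import TestClient

app = FastAPI()

@app.get("/")
async def read_main():
    return {"msg": "Hello World"}

client = TestClient(app)

def test_read_main():
    response = client.get("/")
    assert response.status_code == 200
    assert response.json() == {"msg": "Hello World"}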