blob: 90d4eb526c04747b6f17daa2ce51c456985521e1 [file] [log] [blame]
# This file is Copyright 2015, Will Bond <will@wbond.net>, licensed to Google
# under the terms of the Google Inbound Service Agreement, signed August 15 2015
def data(provider_method, first_param_name_suffix=False):
"""
A method decorator for unittest.TestCase classes that configured a
static method to be used to provide multiple sets of test data to a single
test
:param provider_method:
The name of the staticmethod of the class to use as the data provider
:param first_param_name_suffix:
If the first parameter for each set should be appended to the method
name to generate the name of the test. Otherwise integers are used.
:return:
The decorated function
"""
def test_func_decorator(test_func):
test_func._provider_method = provider_method
test_func._provider_name_suffix = first_param_name_suffix
return test_func
return test_func_decorator
def data_class(cls):
"""
A class decorator that works with the @provider decorator to generate test
method from a data provider
"""
def generate_test_func(name, original_function, num, params):
if original_function._provider_name_suffix:
data_name = params[0]
params = params[1:]
else:
data_name = num
expanded_name = 'test_%s_%s' % (name, data_name)
# We used expanded variable names here since this line is present in
# backtraces that are generated from test failures.
generated_test_function = lambda self: original_function(self, *params)
setattr(cls, expanded_name, generated_test_function)
for name in dir(cls):
func = getattr(cls, name)
if hasattr(func, '_provider_method'):
num = 1
for params in getattr(cls, func._provider_method)():
generate_test_func(name, func, num, params)
num += 1
return cls