在Python中,枚举可以像我们在类中定义类变量一样来实现:每一个类变量就是一个枚举项,访问枚举项的方式为:类名加上类变量名(例如 Color.RED)。
虽然这样可以解决问题,但并不严谨,也不够安全,因为它无法保证以下两点:
1、枚举类中,不应该存在key相同的枚举项(类变量)
2、不允许在类外直接修改枚举项的值
下面来看看如何实现枚举类型:
import operator
class EnumValue(object):
    """A single member of an enum built by EnumMetaclass.

    Members are equal when they belong to the same parent enum (by name) and
    have equal values; hashing uses the same (parent, value) pair so equal
    members always hash equally. Ordering comparisons between members of
    different enums raise TypeError.
    """

    def __init__(self, parent_name, name, value):
        """
        Args:
            parent_name: name of the enum this member belongs to.
            name: the member's name, returned by ``str()``.
            value: the member's value, used for equality, ordering, hashing
                and ``int()``.
        """
        self._parent_name = parent_name
        self._name = name
        self._value = value

    def _parents_equal(self, other):
        """Return True if ``other`` has a ``_parent_name`` equal to ours."""
        return (
            hasattr(other, '_parent_name')
            and self._parent_name == other._parent_name)

    def _check_parents_equal(self, other):
        """Raise TypeError unless ``other`` belongs to the same enum."""
        if not self._parents_equal(other):
            raise TypeError(
                'This operation is valid only for enum values of the same type')

    def __eq__(self, other):
        # Non-members (no _parent_name) simply compare unequal.
        return self._parents_equal(other) and self._value == other._value

    def __ne__(self, other):
        return not self.__eq__(other)

    def __lt__(self, other):
        self._check_parents_equal(other)
        return self._value < other._value

    def __le__(self, other):
        self._check_parents_equal(other)
        return self._value <= other._value

    def __gt__(self, other):
        self._check_parents_equal(other)
        return self._value > other._value

    def __ge__(self, other):
        self._check_parents_equal(other)
        return self._value >= other._value

    def __hash__(self):
        # Hash exactly the pair __eq__ compares. The previous
        # ``parent_name + str(value)`` gave equal members different hashes
        # (1 vs 1.0, True vs 1) and collided across parents ('A'+'12' vs
        # 'A1'+'2').
        try:
            return hash((self._parent_name, self._value))
        except TypeError:
            # Unhashable values (e.g. lists) used to hash via str(); keep
            # that working.
            return hash((self._parent_name, str(self._value)))

    def __repr__(self):
        return '{}({!r}, {!r}, {!r})'.format(
            self.__class__.__name__, self._parent_name, self._name, self._value)

    def __int__(self):
        return int(self._value)

    def __str__(self):
        return str(self._name)
class EnumMetaclass(type):
    """Metaclass that turns a class namespace into a fixed set of EnumValues.

    Each non-dunder namespace entry becomes a member named by its upper-cased
    key. Members are looked up case-insensitively by name (attribute access
    or ``cls['name']``) or by value (``cls[value]``). Assigning to a member
    name on the class raises AttributeError, so members cannot be rebound
    from outside.
    """

    def __new__(cls, name, bases, dct):
        """Build the enum class ``name`` with one EnumValue per entry of ``dct``.

        The class also gets a ``name`` attribute holding the enum's name.

        Raises:
            ValueError: if two keys collide once upper-cased (e.g. ``view``
                and ``VIEW``); previously one of them was silently dropped.
        """
        namespace = {}
        members = {}
        for key, value in dct.items():
            # Dunder entries (__module__, __qualname__, __classcell__, ...)
            # come from ``class`` statement machinery and are not members.
            if len(key) > 4 and key.startswith('__') and key.endswith('__'):
                namespace[key] = value
                continue
            member_name = key.upper()
            if member_name in members:
                raise ValueError(
                    'Duplicate enum member name: {!r}'.format(member_name))
            members[member_name] = EnumValue(name, member_name, value)
        namespace['name'] = name
        namespace['_enums_by_str'] = members
        # Share one EnumValue per member between both lookup tables. If two
        # members share a value, the later one wins value lookup (as before).
        namespace['_enums_by_int'] = dict(
            (member._value, member) for member in members.values())
        return super(EnumMetaclass, cls).__new__(cls, name, bases, namespace)

    def __getattr__(cls, name):
        """Resolve otherwise-unknown class attributes as member lookups.

        Raises:
            AttributeError: if ``name`` matches no member.
        """
        try:
            return cls.__getitem__(name)
        except KeyError:
            pass
        # Raised outside the ``except`` so the internal KeyError is not
        # chained onto the traceback.
        raise AttributeError(
            '{!r} has no enum member {!r}'.format(cls.name, name))

    def __setattr__(cls, name, value):
        """Refuse to rebind an enum member; other attributes are set normally."""
        # Upper-case to mirror the case-insensitive lookup, so both
        # ``cls.VIEW = x`` and ``cls.view = x`` are refused.
        if name.upper() in cls._enums_by_str:
            raise AttributeError(
                'Cannot reassign enum member {!r}'.format(name))
        super(EnumMetaclass, cls).__setattr__(name, value)

    def __getitem__(cls, name):
        """Return the member whose name (case-insensitive) or value is ``name``.

        Raises:
            KeyError: if no member matches.
        """
        try:
            key = name.upper()
        except AttributeError:
            key = name  # non-string keys (e.g. ints) are used as-is
        member = cls._enums_by_str.get(key)
        if member is not None:
            return member
        return cls._enums_by_int[key]

    def __repr__(cls):
        return '{}({!r}, {})'.format(
            cls.__class__.__name__,
            cls.name,
            ', '.join('{}={}'.format(v._name, v._value)
                      for v in sorted(cls._enums_by_str.values())))

    def values(cls):
        """Return all members sorted by value."""
        return sorted(cls._enums_by_str.values())

    def _values_comparison(cls, item, comparison_operator):
        """
        Return a sorted list of members such that
        comparison_operator(member, item) is True.
        """
        return sorted(
            v for v in cls._enums_by_str.values()
            if comparison_operator(v, item))

    def values_lt(cls, item):
        """Return members less than ``item``, sorted."""
        return cls._values_comparison(item, operator.lt)

    def values_le(cls, item):
        """Return members less than or equal to ``item``, sorted."""
        return cls._values_comparison(item, operator.le)

    def values_gt(cls, item):
        """Return members greater than ``item``, sorted."""
        return cls._values_comparison(item, operator.gt)

    def values_ge(cls, item):
        """Return members greater than or equal to ``item``, sorted."""
        return cls._values_comparison(item, operator.ge)

    def values_ne(cls, item):
        """Return members not equal to ``item``, sorted."""
        return cls._values_comparison(item, operator.ne)
def enum_factory(name, **kwargs):
    """Create an enum class called ``name`` whose members are ``kwargs``.

    The keyword names become the member names (EnumMetaclass upper-cases
    them) and the keyword values become the member values.
    """
    member_definitions = kwargs
    no_bases = tuple()
    return EnumMetaclass(name, no_bases, member_definitions)
下面是上述枚举实现的单元测试:
import unittest
class EnumTestCase(unittest.TestCase):
    """Unit tests for enum_factory and the EnumValue members it produces."""

    def setUp(self):
        # Build fresh enums for every test so no state is shared between tests.
        self.ProfileAction = enum_factory(
            'ProfileAction', VIEW=1, EDIT_OWN=2, EDIT_PUBLIC=3, EDIT_FULL=4)
        self.DashboardAction = enum_factory(
            'DashboardAction', VIEW=1, EDIT_OWN=2, EDIT_PUBLIC=3, EDIT_FULL=4)

    @staticmethod
    def _profile_members(*pairs):
        """Build the expected ProfileAction EnumValues from (name, value) pairs."""
        return [EnumValue('ProfileAction', member_name, member_value)
                for member_name, member_value in pairs]

    def test_repr(self):
        self.assertEqual(
            repr(self.ProfileAction),
            "EnumMetaclass('ProfileAction', VIEW=1, EDIT_OWN=2, EDIT_PUBLIC=3, EDIT_FULL=4)")

    def test_value_repr(self):
        self.assertEqual(
            repr(self.ProfileAction.VIEW),
            "EnumValue('ProfileAction', 'VIEW', 1)")

    def test_attribute_error(self):
        self.assertRaises(
            AttributeError, getattr, self.ProfileAction, 'ASDFASDF')

    def test_cast_to_str(self):
        self.assertEqual('VIEW', str(self.ProfileAction.VIEW))

    def test_cast_to_int(self):
        self.assertEqual(1, int(self.ProfileAction.VIEW))

    def test_access_by_str(self):
        profile = self.ProfileAction
        self.assertEqual(profile['VIEW'], profile.VIEW)

    def test_access_by_int(self):
        profile = self.ProfileAction
        self.assertEqual(profile[1], profile.VIEW)

    def test_equality(self):
        profile = self.ProfileAction
        for candidate in (profile.VIEW, profile['VIEW'], profile[1]):
            self.assertEqual(candidate, profile.VIEW)

    def test_inequality(self):
        profile = self.ProfileAction
        for candidate in (profile.VIEW, profile['VIEW'], profile[1]):
            self.assertNotEqual(candidate, profile.EDIT_OWN)
        self.assertNotEqual(profile.VIEW, self.DashboardAction.VIEW)

    def test_invalid_comparison(self):
        with self.assertRaises(TypeError) as context:
            self.ProfileAction.VIEW < self.DashboardAction.EDIT_OWN
        self.assertEqual(
            str(context.exception),
            'This operation is valid only for enum values of the same type')

    def test_values(self):
        expected = self._profile_members(
            ('VIEW', 1), ('EDIT_OWN', 2), ('EDIT_PUBLIC', 3), ('EDIT_FULL', 4))
        self.assertEqual(self.ProfileAction.values(), expected)

    def test_values_lt(self):
        profile = self.ProfileAction
        expected = self._profile_members(('VIEW', 1), ('EDIT_OWN', 2))
        self.assertEqual(profile.values_lt(profile.EDIT_PUBLIC), expected)

    def test_values_le(self):
        profile = self.ProfileAction
        expected = self._profile_members(
            ('VIEW', 1), ('EDIT_OWN', 2), ('EDIT_PUBLIC', 3))
        self.assertEqual(profile.values_le(profile.EDIT_PUBLIC), expected)

    def test_values_gt(self):
        profile = self.ProfileAction
        expected = self._profile_members(('EDIT_FULL', 4))
        self.assertEqual(profile.values_gt(profile.EDIT_PUBLIC), expected)

    def test_values_ge(self):
        profile = self.ProfileAction
        expected = self._profile_members(('EDIT_PUBLIC', 3), ('EDIT_FULL', 4))
        self.assertEqual(profile.values_ge(profile.EDIT_PUBLIC), expected)

    def test_values_ne(self):
        profile = self.ProfileAction
        expected = self._profile_members(
            ('VIEW', 1), ('EDIT_OWN', 2), ('EDIT_FULL', 4))
        self.assertEqual(profile.values_ne(profile.EDIT_PUBLIC), expected)

    def test_intersection_with_same_type(self):
        profile = self.ProfileAction
        first = {profile.VIEW, profile.EDIT_OWN}
        second = {profile.VIEW, profile.EDIT_PUBLIC}
        self.assertEqual(first & second, {profile.VIEW})

    def test_intersection_with_different_types(self):
        profile, dashboard = self.ProfileAction, self.DashboardAction
        first = {profile.VIEW, profile.EDIT_OWN}
        second = {dashboard.VIEW, dashboard.EDIT_PUBLIC}
        self.assertEqual(first & second, set())