diff --git a/httpretty/__init__.py b/httpretty/__init__.py index a752b452..645f8e29 100644 --- a/httpretty/__init__.py +++ b/httpretty/__init__.py @@ -27,12 +27,18 @@ __version__ = version = '0.8.10' -from .core import httpretty, httprettified, EmptyRequestHeaders +from .core import ( + httpretty, httprettified, httprettified_with_overrides, + EmptyRequestHeaders) from .errors import HTTPrettyError, UnmockedError from .core import URIInfo HTTPretty = httpretty activate = httprettified +# this is to ensure backwards compatibility (decorator with and without args) +# it'd be nice to deprecate @httpretty.activate in favor of +# @httpretty.activate() in a future release +activate_with_overrides = httprettified_with_overrides enable = httpretty.enable register_uri = httpretty.register_uri diff --git a/httpretty/core.py b/httpretty/core.py index bcd69b2c..ec370501 100644 --- a/httpretty/core.py +++ b/httpretty/core.py @@ -1019,61 +1019,60 @@ def is_enabled(cls): return cls._is_enabled @classmethod - def enable(cls): + def enable(cls, overrides=None): cls._is_enabled = True # Some versions of python internally shadowed the # SocketType variable incorrectly https://bugs.python.org/issue20386 bad_socket_shadow = (socket.socket != socket.SocketType) - socket.socket = fakesock.socket - socket._socketobject = fakesock.socket - if not bad_socket_shadow: - socket.SocketType = fakesock.socket - - socket.create_connection = create_fake_connection - socket.gethostname = fake_gethostname - socket.gethostbyname = fake_gethostbyname - socket.getaddrinfo = fake_getaddrinfo + _patch(socket, 'socket', fakesock.socket, overrides) + _patch(socket, '_socketobject', fakesock.socket, overrides) - socket.__dict__['socket'] = fakesock.socket - socket.__dict__['_socketobject'] = fakesock.socket if not bad_socket_shadow: - socket.__dict__['SocketType'] = fakesock.socket + _patch(socket, 'SocketType', fakesock.socket, overrides) - socket.__dict__['create_connection'] = create_fake_connection - socket.__dict__['gethostname'] = fake_gethostname - socket.__dict__['gethostbyname'] = fake_gethostbyname - socket.__dict__['getaddrinfo'] = fake_getaddrinfo + _patch(socket, 'create_connection', create_fake_connection, overrides) + _patch(socket, 'gethostname', fake_gethostname, overrides) + _patch(socket, 'gethostbyname', fake_gethostbyname, overrides) + _patch(socket, 'getaddrinfo', fake_getaddrinfo, overrides) if socks: - socks.socksocket = fakesock.socket - socks.__dict__['socksocket'] = fakesock.socket + _patch(socks, 'socksocket', fakesock.socket, overrides) if ssl: - ssl.wrap_socket = fake_wrap_socket - ssl.SSLSocket = FakeSSLSocket - - ssl.__dict__['wrap_socket'] = fake_wrap_socket - ssl.__dict__['SSLSocket'] = FakeSSLSocket + _patch(ssl, 'wrap_socket', fake_wrap_socket, overrides) + _patch(ssl, 'SSLSocket', FakeSSLSocket, overrides) if not PY3: - ssl.sslwrap_simple = fake_wrap_socket - ssl.__dict__['sslwrap_simple'] = fake_wrap_socket + _patch(ssl, 'sslwrap_simple', fake_wrap_socket, overrides) + + +def _patch(module, name, patch, overrides=None): + assert hasattr(module, name) + try: + patch = overrides[module.__name__][name] + except (KeyError, TypeError): + pass + + setattr(module, name, patch) + module.__dict__[name] = patch class httprettized(object): + def __init__(self, overrides=None): + self._overrides = overrides + def __enter__(self): httpretty.reset() - httpretty.enable() + httpretty.enable(self._overrides) def __exit__(self, exc_type, exc_value, traceback): httpretty.disable() httpretty.reset() -def httprettified(test): - "A decorator tests that use HTTPretty" +def wrap_with_overrides(test, overrides=None): def decorate_class(klass): for attr in dir(klass): if not attr.startswith('test_'): @@ -1096,3 +1095,35 @@ def wrapper(*args, **kw): if isinstance(test, ClassTypes): return decorate_class(test) return decorate_callable(test) + + +def httprettified_with_overrides(overrides): + def wrapper(test): + return httprettified(test, overrides=overrides) + return wrapper + + +def httprettified(test, overrides=None): + "A decorator tests that use HTTPretty" + def decorate_class(klass): + for attr in dir(klass): + if not attr.startswith('test_'): + continue + + attr_value = getattr(klass, attr) + if not hasattr(attr_value, "__call__"): + continue + + setattr(klass, attr, decorate_callable(attr_value)) + return klass + + def decorate_callable(test): + @functools.wraps(test) + def wrapper(*args, **kw): + with httprettized(overrides=overrides): + return test(*args, **kw) + return wrapper + + if isinstance(test, ClassTypes): + return decorate_class(test) + return decorate_callable(test) diff --git a/tests/unit/test_core.py b/tests/unit/test_core.py index 9249dbb8..5d0aaf38 100644 --- a/tests/unit/test_core.py +++ b/tests/unit/test_core.py @@ -3,13 +3,16 @@ from __future__ import unicode_literals import json import errno +import socket from datetime import datetime from mock import Mock, patch, call from sure import expect from httpretty.compat import StringIO -from httpretty.core import HTTPrettyRequest, FakeSSLSocket, fakesock, httpretty +from httpretty.core import ( + HTTPrettyRequest, FakeSSLSocket, fakesock, httpretty, + httprettified_with_overrides) class SocketErrorStub(Exception): @@ -601,3 +604,50 @@ def real_sendall(self, data): # Then the entry should have that body httpretty.last_request.body.should.equal(b'BLABLABLABLA') + + +_HOSTNAME = 'eric idle' +_GETHOSTNAME = lambda: _HOSTNAME +_HOSTBYNAME= '127.0.0.1' +_GETHOSTBYNAME = lambda _: _GETHOSTBYNAME + +# effectively bypassing the mock +_GETADDRINFO = socket.getaddrinfo + + +overrides = { + 'socket': { + 'gethostname': _GETHOSTNAME, + 'gethostbyname': _GETHOSTBYNAME, + 'getaddrinfo': _GETADDRINFO, + }, +} + + +def _test_patching(): + expect(socket.gethostname()).to.equal(_HOSTNAME) + expect(socket.gethostname).to.equal(_GETHOSTNAME) + expect(socket.gethostbyname( + 'localhost')).to.equal(_GETHOSTBYNAME('localhost')) + expect(socket.gethostbyname).to.equal(_GETHOSTBYNAME) + expect(socket.getaddrinfo).to.equal(_GETADDRINFO) + expect(socket.getaddrinfo( + 'localhost', 80)).to.equal(_GETADDRINFO('localhost', 80)) + + +def test_explicit_overrides(): + httpretty.enable(overrides) + _test_patching() + httpretty.disable() + + +# test function decorator +@httprettified_with_overrides(overrides) +def test_function_decorator_overrides(): + _test_patching() + + +@httprettified_with_overrides(overrides) +class TestClass(object): + def test_class_decorator_overrides(self): + _test_patching()