#11 Patch pyspark within python (#14)
Patch local_connect_and_auth in Python at import time rather than rewriting pyspark's java_gateway.py file on disk.
abronte authored Jul 15, 2019
1 parent 3d8a186 commit 7a781e0
Showing 4 changed files with 64 additions and 234 deletions.
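
For context, the change below replaces on-disk patching with a runtime monkey patch. A minimal sketch of that pattern, not the project's exact code (the wrapper and variable names here are illustrative assumptions):

    from pyspark import java_gateway

    # Keep a reference to the original implementation.
    _original_connect = java_gateway.local_connect_and_auth

    def patched_connect_and_auth(port, auth_secret):
        # Illustrative wrapper: the real replacement in this commit first opens a
        # tunnel and connects to the tunneled port; this sketch just delegates to
        # the original so it stays self-contained.
        return _original_connect(port, auth_secret)

    # Rebinding the module attribute is all the "patching" needed at runtime;
    # nothing on disk changes, so nothing has to be restored at exit.
    java_gateway.local_connect_and_auth = patched_connect_and_auth
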
86 changes: 61 additions & 25 deletions pyspark_gateway/__init__.py
@@ -8,6 +8,67 @@

from pyspark_gateway.server import HTTP_PORT, GATEWAY_PORT

# Function to patch from pyspark
#
# License for below function from pyspark
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
def local_connect_and_auth(port, auth_secret):
    from pyspark_gateway import PysparkGateway

    tmp_port = PysparkGateway.open_tmp_tunnel(port)

    """
    Connect to local host, authenticate with it, and return a (sockfile,sock) for that connection.
    Handles IPV4 & IPV6, does some error handling.
    :param port
    :param auth_secret
    :return: a tuple with (sockfile, sock)
    """
    sock = None
    errors = []
    # Support for both IPv4 and IPv6.
    # On most of IPv6-ready systems, IPv6 will take precedence.
    for res in socket.getaddrinfo(PysparkGateway.host, tmp_port, socket.AF_UNSPEC, socket.SOCK_STREAM):
        af, socktype, proto, _, sa = res
        try:
            sock = socket.socket(af, socktype, proto)
            sock.settimeout(15)
            sock.connect(sa)
            sockfile = sock.makefile("rwb", 65536)
            _do_server_auth(sockfile, auth_secret)
            return (sockfile, sock)
        except socket.error as e:
            emsg = _exception_message(e)
            errors.append("tried to connect to %s, but an error occured: %s" % (sa, emsg))
            sock.close()
            sock = None
    else:
        raise Exception("could not open socket: %s" % errors)

try:
    from pyspark import java_gateway
except:
    import findspark
    findspark.init()

    from pyspark import java_gateway

java_gateway.local_connect_and_auth = local_connect_and_auth

class PysparkGateway(object):
    host = None
    http_url = None
@@ -22,7 +83,6 @@ def __init__(self,
        PysparkGateway.http_url = self.http_url
        PysparkGateway.host = self.host

        self.patch()
        self.check_version()
        self.start_gateway()

@@ -31,30 +91,6 @@ def open_tmp_tunnel(cls, port):
        r = requests.post(cls.http_url+'/tmp_tunnel', json={'port': port})
        return r.json()['port']

    def patch(self):
        path = os.path.dirname(os.path.realpath(__file__))+'/patch_files/java_gateway_patch.py'
        patch_file = open(path, 'r').read()

        pkg = pkgutil.get_loader('pyspark')

        path = pkg.get_filename().split('/')[:-1]
        path.append('java_gateway.py')
        path = '/'.join(path)

        if os.path.exists(path+'c'):
            os.remove(path+'c')

        original_file = open(path, 'r').read()

        with open(path, 'w') as f:
            f.write(patch_file)

        def put_back(data, path):
            with open(path, 'w') as f:
                f.write(data)

        atexit.register(put_back, original_file, path)

    def check_version(self):
        from pyspark_gateway.spark_version import spark_version, valid_spark_version
        from pyspark_gateway.version import __version__
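
The removed patch() method above located the installed pyspark package, overwrote its java_gateway.py with a patched copy, and registered an atexit hook to restore the original. A hedged sketch of just the file-location step it relied on, assuming pyspark is importable (variable names are illustrative):

    import os
    import pkgutil

    # Resolve where the installed pyspark package lives; this is the file the
    # old approach overwrote on disk.
    pkg = pkgutil.get_loader('pyspark')
    java_gateway_path = os.path.join(os.path.dirname(pkg.get_filename()), 'java_gateway.py')
    print(java_gateway_path)
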
206 changes: 0 additions & 206 deletions pyspark_gateway/patch_files/java_gateway_patch.py

This file was deleted.

4 changes: 1 addition & 3 deletions setup.py
@@ -8,9 +8,7 @@
setup(
    name='PysparkGateway',
    version='0.0.18',
    packages=[
        'pyspark_gateway',
        'pyspark_gateway.patch_files'],
    packages=['pyspark_gateway'],
    license='Apache 2.0',
    description='Connect Pyspark to remote clusters',
    long_description=readme,
2 changes: 2 additions & 0 deletions tests/test_pyspark_gateway.py
@@ -7,6 +7,8 @@
import requests
import pandas

import pyspark

from pyspark_gateway import server
from pyspark_gateway import PysparkGateway

