Skip to content

Commit

Permalink
feat(smb): improved smb code for file downloads
Browse files Browse the repository at this point in the history
If connections are made too closely together sometimes windows will refuse a new connection. this
update will retry connections.
  • Loading branch information
christopherpickering committed Aug 29, 2023
1 parent a697ffc commit 7f27b1c
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 4 deletions.
14 changes: 11 additions & 3 deletions runner/scripts/em_smb.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,10 +190,15 @@ def _walk(
new_path = str(Path(directory).joinpath(dirname))
yield from self._walk(new_path)

def __load_file(self, file_name: str) -> IO[str]:
def __load_file(self, file_name: str, index: int, length: int) -> IO[str]:
RunnerLog(
self.task, self.run_id, 10, f"({index} of {length}) downloading {file_name}"
)

director = urllib.request.build_opener(SMBHandler)

password = em_decrypt(self.password, app.config["PASS_KEY"])

open_file_for_read = director.open(
f"smb://{self.username}:{password}@{self.server_name},{self.server_ip}/{self.share_name}/{file_name}"
)
Expand Down Expand Up @@ -279,9 +284,12 @@ def read(self, file_name: str) -> List[IO[str]]:
)

# if a file was found, try to open.
return [self.__load_file(file_name) for file_name in file_list]
return [
self.__load_file(file_name, i, len(file_list))
for i, file_name in enumerate(file_list, 1)
]

return [self.__load_file(file_name)]
return [self.__load_file(file_name, 1, 1)]
except BaseException as e:
raise RunnerException(
self.task,
Expand Down
25 changes: 24 additions & 1 deletion runner/scripts/smb_fix.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import socket
import sys
import tempfile
import time
import urllib.error
import urllib.parse
import urllib.request
Expand All @@ -29,6 +30,7 @@
from urllib.response import addinfourl

from nmb.NetBIOS import NetBIOS
from smb.base import NotConnectedError, SMBTimeout
from smb.SMBConnection import SMBConnection

USE_NTLM = True
Expand Down Expand Up @@ -90,7 +92,28 @@ def smb_open(self, req):
conn = SMBConnection(
user, passwd, myname, server_name, domain=domain, use_ntlm_v2=USE_NTLM
)
conn.connect(host, port)

# retry
retry = 0
while True:
try:
connected = conn.connect(host, port, timeout=120)
if not connected:
raise AssertionError()
break

except (
AssertionError,
ConnectionResetError,
SMBTimeout,
NotConnectedError,
) as e:
if retry <= 10:
retry += 1
time.sleep(5) # wait 5 sec before retrying
continue

raise ValueError(f"Connection failed.\n{e}")

headers = email.message.Message()
if req.data:
Expand Down

0 comments on commit 7f27b1c

Please sign in to comment.