diff --git a/docs/changelog-fragments/661.bugfix.rst b/docs/changelog-fragments/661.bugfix.rst new file mode 100644 index 000000000..85a366f81 --- /dev/null +++ b/docs/changelog-fragments/661.bugfix.rst @@ -0,0 +1 @@ +Uploading large files over SCP no longer fails -- by :user:`Jakuje`. diff --git a/src/pylibsshext/scp.pyx b/src/pylibsshext/scp.pyx index 6589ff431..1b918bdd7 100644 --- a/src/pylibsshext/scp.pyx +++ b/src/pylibsshext/scp.pyx @@ -74,15 +74,25 @@ cdef class SCP: ) try: + # Read buffer + read_buffer_size = min(file_size, SCP_MAX_CHUNK) + # Begin to send to the file rc = libssh.ssh_scp_push_file(scp, filename_b, file_size, file_mode) if rc != libssh.SSH_OK: raise LibsshSCPException("Can't open remote file: %s" % self._get_ssh_error_str()) - # Write to the open file - rc = libssh.ssh_scp_write(scp, PyBytes_AS_STRING(f.read()), file_size) - if rc != libssh.SSH_OK: - raise LibsshSCPException("Can't write to remote file: %s" % self._get_ssh_error_str()) + remaining_bytes_to_read = file_size + while remaining_bytes_to_read > 0: + # Read the chunk from local file + read_bytes = min(remaining_bytes_to_read, read_buffer_size) + read_buffer = f.read(read_bytes) + + # Write to the open file + rc = libssh.ssh_scp_write(scp, PyBytes_AS_STRING(read_buffer), read_bytes) + if rc != libssh.SSH_OK: + raise LibsshSCPException("Can't write to remote file: %s" % self._get_ssh_error_str()) + remaining_bytes_to_read -= read_bytes finally: libssh.ssh_scp_close(scp) libssh.ssh_scp_free(scp) diff --git a/tests/unit/scp_test.py b/tests/unit/scp_test.py index dd2d1b9f9..81558950e 100644 --- a/tests/unit/scp_test.py +++ b/tests/unit/scp_test.py @@ -120,3 +120,9 @@ def test_get_large(dst_path, src_path_large, ssh_scp, large_payload): """Check that SCP file download gets over 64kB of data.""" ssh_scp.get(str(src_path_large), str(dst_path)) assert dst_path.read_bytes() == large_payload + + +def test_put_large(dst_path, src_path_large, ssh_scp, large_payload): + """Check that SCP file download gets over 64kB of data.""" + ssh_scp.put(str(src_path_large), str(dst_path)) + assert dst_path.read_bytes() == large_payload