Skip to content

Commit

Permalink
BXC-4676 - Service for transferring files to remote destination (#106)
Browse files Browse the repository at this point in the history
* Add SshClientService, which pulls out the functionality for running scp and ssh commands into their own class

* Add service for transferring source files to a remote destination, as well as create any necessary directories. Updates TestSshServer to work with key authentication, and to support both scp and ssh commands
  • Loading branch information
bbpennel committed Sep 4, 2024
1 parent 3746798 commit b0f5fb5
Show file tree
Hide file tree
Showing 9 changed files with 513 additions and 26 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,13 @@
import edu.unc.lib.boxc.migration.cdm.model.CdmEnvironment;
import edu.unc.lib.boxc.migration.cdm.model.MigrationProject;
import edu.unc.lib.boxc.migration.cdm.services.ChompbConfigService.ChompbConfig;
import org.apache.sshd.client.SshClient;
import org.apache.sshd.common.SshException;
import org.apache.sshd.common.config.keys.FilePasswordProvider;
import edu.unc.lib.boxc.migration.cdm.util.SshClientService;
import org.apache.sshd.scp.client.ScpClient;
import org.apache.sshd.scp.client.ScpClientCreator;

import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.concurrent.TimeUnit;
import java.util.function.Consumer;

/**
Expand Down Expand Up @@ -89,27 +85,14 @@ public static Path getExportedCpdsPath(MigrationProject project) {
* @param downloadBlock method containing download operations
*/
public void executeDownloadBlock(Consumer<ScpClient> downloadBlock) {
SshClient client = SshClient.setUpDefaultClient();
client.setFilePasswordProvider(FilePasswordProvider.of(sshPassword));
client.start();
var cdmEnvId = project.getProjectProperties().getCdmEnvironmentId();
var cdmEnvConfig = chompbConfig.getCdmEnvironments().get(cdmEnvId);
try (var sshSession = client.connect(sshUsername, cdmEnvConfig.getSshHost(), cdmEnvConfig.getSshPort())
.verify(SSH_TIMEOUT_SECONDS, TimeUnit.SECONDS)
.getSession()) {
sshSession.addPasswordIdentity(sshPassword);

sshSession.auth().verify(SSH_TIMEOUT_SECONDS, TimeUnit.SECONDS);
var cdmEnvConfig = getCdmEnvironment();
var sshService = new SshClientService();
sshService.setSshHost(cdmEnvConfig.getSshHost());
sshService.setSshPort(cdmEnvConfig.getSshPort());
sshService.setSshUsername(sshUsername);
sshService.setSshPassword(sshPassword);

var scpClientCreator = ScpClientCreator.instance();
var scpClient = scpClientCreator.createScpClient(sshSession);
downloadBlock.accept(scpClient);
} catch (IOException e) {
if (e instanceof SshException && e.getMessage().contains("No more authentication methods available")) {
throw new MigrationException("Authentication to server failed, check username or password");
}
throw new MigrationException("Failed to establish ssh session", e);
}
sshService.executeScpBlock(downloadBlock);
}

private CdmEnvironment getCdmEnvironment() {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
package edu.unc.lib.boxc.migration.cdm.services;

import edu.unc.lib.boxc.migration.cdm.model.SourceFilesInfo;
import edu.unc.lib.boxc.migration.cdm.util.SshClientService;

import java.io.IOException;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedDeque;
import java.util.stream.Collectors;

/**
* Service for transferring source files from the local server to a remote destination.
* @author bbpennel
*/
public class SourceFilesToRemoteService {
private SourceFileService sourceFileService;
private SshClientService sshClientService;
private int concurrentTransfers = 5;

/**
* Transfer files from the source CDM server to the remote destination.
* Files are transferred in parallel, up to concurrentTransfers at a time.
* @param destinationPath
* @throws IOException
*/
public void transferFiles(Path destinationPath) throws IOException {
var sourceMappings = sourceFileService.loadMappings();
final Path destinationBasePath = destinationPath.toAbsolutePath();
// Get all the source paths as a thread safe queue
var sourcePaths = sourceMappings.getMappings().stream()
.map(SourceFilesInfo.SourceFileMapping::getFirstSourcePath)
.collect(Collectors.toList());
var pathsDeque = new ConcurrentLinkedDeque<Path>(sourcePaths);
// For tracking if a parent directory has already been created
Set<String> createdParentsSet = ConcurrentHashMap.newKeySet();
// Create the remote destination directory
sshClientService.executeRemoteCommand("mkdir -p " + destinationBasePath);
createdParentsSet.add(destinationBasePath.toString());

var threads = new ArrayList<Thread>(concurrentTransfers);
// Start threads for parallel transfer of files
for (int i = 0; i < concurrentTransfers; i++) {
var thread = createTransferThread(pathsDeque, destinationBasePath, createdParentsSet);
thread.start();
threads.add(thread);
}

// Wait for all threads to finish
threads.forEach(t -> {
try {
t.join();
} catch (InterruptedException e) {
throw new RuntimeException("Thread interrupted", e);
}
});
}

private Thread createTransferThread(ConcurrentLinkedDeque<Path> pathsDeque,
Path destinationBasePath,
Set<String> createdParentsSet) {
var thread = new Thread(() -> {
Path nextPath;
while ((nextPath = pathsDeque.poll()) != null) {
final Path sourcePath = nextPath;
sshClientService.executeSshBlock((sshClient) -> {
var sourceRelative = sourcePath.toAbsolutePath().toString().substring(1);
var destPath = destinationBasePath.resolve(sourceRelative);
var destParentPath = destPath.getParent();
// Create the parent path if we haven't already done so
synchronized (createdParentsSet) {
if (!createdParentsSet.contains(destParentPath.toString())) {
createdParentsSet.add(destParentPath.toString());
sshClientService.executeRemoteCommand("mkdir -p " + destPath.getParent());
}
}
// Upload the file to the appropriate path on the remote server
sshClientService.executeScpBlock(sshClient, (scpClient) -> {
try {
scpClient.upload(sourcePath.toString(), destPath.toString());
} catch (IOException e) {
throw new RuntimeException("Failed to transfer file " + sourcePath, e);
}
});
});
}
});
return thread;
}

public void setSourceFileService(SourceFileService sourceFileService) {
this.sourceFileService = sourceFileService;
}

public void setSshClientService(SshClientService sshClientService) {
this.sshClientService = sshClientService;
}

public void setConcurrentTransfers(int concurrentTransfers) {
this.concurrentTransfers = concurrentTransfers;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
package edu.unc.lib.boxc.migration.cdm.util;

import edu.unc.lib.boxc.migration.cdm.exceptions.MigrationException;
import org.apache.sshd.client.SshClient;
import org.apache.sshd.client.channel.ClientChannel;
import org.apache.sshd.client.channel.ClientChannelEvent;
import org.apache.sshd.client.session.ClientSession;
import org.apache.sshd.common.SshException;
import org.apache.sshd.common.config.keys.FilePasswordProvider;
import org.apache.sshd.common.keyprovider.KeyPairProvider;
import org.apache.sshd.common.util.io.resource.PathResource;
import org.apache.sshd.scp.client.ScpClient;
import org.apache.sshd.scp.client.ScpClientCreator;
import org.apache.sshd.common.util.security.SecurityUtils;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.security.GeneralSecurityException;
import java.security.KeyPair;
import java.util.EnumSet;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Consumer;

import static java.util.Collections.singletonList;

/**
* Service for executing remote commands and transfers
* @author bbpennel
*/
public class SshClientService {
private static final int SSH_TIMEOUT_SECONDS = 10;

private String sshHost;
private int sshPort;
private String sshUsername;
private String sshPassword;
private Path sshKeyPath;
private KeyPair sshKeyPair;

public void initialize() {
if (sshKeyPath != null) {
try {
sshKeyPair = SecurityUtils.loadKeyPairIdentities(
null, new PathResource(sshKeyPath), Files.newInputStream(sshKeyPath), null
).iterator().next();
} catch (IOException | GeneralSecurityException e) {
throw new MigrationException("Failed to load ssh key", e);
}
}
}

private SshClient buildSshClient() {
SshClient client = SshClient.setUpDefaultClient();
if (sshKeyPair != null) {
client.setKeyIdentityProvider(KeyPairProvider.wrap(singletonList(sshKeyPair)));
} else if (sshPassword != null) {
client.setFilePasswordProvider(FilePasswordProvider.of(sshPassword));
}
return client;
}

private void setupSessionAuthentication(ClientSession session) {
if (sshKeyPair != null) {
session.addPublicKeyIdentity(sshKeyPair);
} else if (sshPassword != null) {
session.addPasswordIdentity(sshPassword);
}
}

/**
* Execute a remote command on the server
* @param command
* @return Response output from the command
*/
public String executeRemoteCommand(String command) {
var response = new AtomicReference<String>();
executeSshBlock(clientSession -> {
response.set(executeRemoteCommand(clientSession, command));
});
return response.get();
}

/**
* Execute a remote command on the server, using the provided session
* @param command
* @return Response output from the command
*/
public String executeRemoteCommand(ClientSession clientSession, String command) {
try (var responseStream = new ByteArrayOutputStream();
ClientChannel channel = clientSession.createExecChannel(command)) {

channel.setOut(responseStream);
channel.setErr(responseStream);
channel.open().verify(5, TimeUnit.SECONDS);

channel.waitFor(EnumSet.of(ClientChannelEvent.CLOSED), 5000);
return responseStream.toString();
} catch (Exception e) {
throw new MigrationException("Failed to execute remote command", e);
}
}

/**
* Execute a block of code with an SSH session
* @param sshBlock
*/
public void executeSshBlock(Consumer<ClientSession> sshBlock) {
SshClient client = buildSshClient();
client.start();
try (var sshSession = client.connect(sshUsername, sshHost, sshPort)
.verify(SSH_TIMEOUT_SECONDS, TimeUnit.SECONDS)
.getSession()) {
setupSessionAuthentication(sshSession);
sshSession.auth().verify(SSH_TIMEOUT_SECONDS, TimeUnit.SECONDS);
sshBlock.accept(sshSession);
} catch (IOException e) {
if (e instanceof SshException && e.getMessage().contains("No more authentication methods available")) {
throw new MigrationException("Authentication to server failed, check username or password", e);
}
throw new MigrationException("Failed to establish ssh session", e);
}
}

/**
* Execute a block of code that requires an SCP client
* @param scpBlock
*/
public void executeScpBlock(Consumer<ScpClient> scpBlock) {
executeSshBlock(client -> {
executeScpBlock(client, scpBlock);
});
}

/**
* Execute a block of code that requires an SCP client, using the provided ssh session
* @param session
* @param scpBlock
*/
public void executeScpBlock(ClientSession session, Consumer<ScpClient> scpBlock) {
var scpClientCreator = ScpClientCreator.instance();
var scpClient = scpClientCreator.createScpClient(session);
scpBlock.accept(scpClient);
}

public void setSshHost(String sshHost) {
this.sshHost = sshHost;
}

public void setSshPort(int sshPort) {
this.sshPort = sshPort;
}

public void setSshUsername(String sshUsername) {
this.sshUsername = sshUsername;
}

public void setSshPassword(String sshPassword) {
this.sshPassword = sshPassword;
}

public void setSshKeyPath(Path sshKeyPath) {
this.sshKeyPath = sshKeyPath;
}
}
Loading

0 comments on commit b0f5fb5

Please sign in to comment.