# pylint: disable=no-member
"""The SSHDriver uses SSH as a transport to implement CommandProtocol and FileTransferProtocol"""
import logging
import os
import shutil
import subprocess
import tempfile
import attr
from ..factory import target_factory
from ..protocol import CommandProtocol, FileTransferProtocol
from .commandmixin import CommandMixin
from .common import Driver
from ..step import step
from .exception import ExecutionError
from ..util.proxy import proxymanager
@target_factory.reg_driver
@attr.s(eq=False)
class SSHDriver(CommandMixin, Driver, CommandProtocol, FileTransferProtocol):
    """SSHDriver - Driver to execute commands via SSH

    On activation the driver starts its own SSH ControlMaster connection,
    whose control socket lives in a private temporary directory, and keeps a
    "keepalive" ssh session open through it. Commands, interactive sessions
    and file transfers are then sent through that control socket.

    Attributes:
        keyfile: optional SSH identity file passed to ssh via ``-i``; it is
            resolved through the environment's config when the target has one.
        stderr_merge: if True, :meth:`run` merges the remote stderr into
            stdout and returns an empty stderr list.
    """
    bindings = {"networkservice": "NetworkService", }
    priorities = {CommandProtocol: 10, FileTransferProtocol: 10}
    keyfile = attr.ib(default="", validator=attr.validators.instance_of(str))
    stderr_merge = attr.ib(default=False, validator=attr.validators.instance_of(bool))

    def __attrs_post_init__(self):
        super().__attrs_post_init__()
        self.logger = logging.getLogger("{}({})".format(self, self.target))
        # Populated by _start_keepalive() during on_activate().
        self._keepalive = None

    def on_activate(self):
        """Assemble the shared ssh options, start the ControlMaster and the keepalive."""
        self.ssh_prefix = ["-o", "LogLevel=ERROR"]
        if self.keyfile:
            keyfile_path = self.keyfile
            if self.target.env:
                keyfile_path = self.target.env.config.resolve_path(self.keyfile)
            self.ssh_prefix += ["-i", keyfile_path]
        if not self.networkservice.password:
            # No password configured: disable password authentication.
            self.ssh_prefix += ["-o", "PasswordAuthentication=no"]

        # The master is started before "-F none" and ControlPath are appended,
        # so unlike later users of ssh_prefix it is not run with "-F none".
        self.control = self._start_own_master()
        self.ssh_prefix += ["-F", "none"]
        if self.control:
            self.ssh_prefix += ["-o", "ControlPath={}".format(self._control_path())]

        self._keepalive = None
        self._start_keepalive()

    def on_deactivate(self):
        """Stop the keepalive, then shut down the ControlMaster.

        The master is cleaned up even if stopping the keepalive raises.
        """
        try:
            self._stop_keepalive()
        finally:
            self._cleanup_own_master()

    def _control_path(self):
        """Return ``self.control`` escaped for use as ssh's ControlPath value."""
        # ssh expands %-tokens in ControlPath, so a literal '%' (e.g. from an
        # "address%ifname" suffix) has to be doubled.
        return self.control.replace('%', '%%')

    def _start_own_master(self):
        """Starts a controlmaster connection in a temporary directory.

        Creates ``self.tmpdir`` and ``self.process`` and returns the path of
        the control socket.

        Raises:
            ValueError: if the network service address contains more than one '%'.
            ExecutionError: if ssh fails, times out, or leaves no control socket.

        If anything fails, the temporary directory is removed again before
        the exception propagates. on_activate() only assigns ``self.control``
        after this method returns, so _cleanup_own_master() cannot be relied
        on to tidy up after a failed start.
        """
        self.tmpdir = tempfile.mkdtemp(prefix='labgrid-ssh-tmp-')
        try:
            return self._connect_own_master()
        except BaseException:
            shutil.rmtree(self.tmpdir, ignore_errors=True)
            raise

    def _connect_own_master(self):
        """Run ssh to create the ControlMaster socket in ``self.tmpdir``; return its path."""
        timeout = 30

        control = os.path.join(
            self.tmpdir, 'control-{}'.format(self.networkservice.address)
        )
        proxy_stderr = os.path.join(self.tmpdir, 'proxy-stderr')

        # use sshpass if we have a password
        args = ["sshpass", "-e"] if self.networkservice.password else []
        args += ["ssh", "-f", *self.ssh_prefix, "-x", "-o", "ConnectTimeout={}".format(timeout),
                 "-o", "ControlPersist=300", "-o",
                 "UserKnownHostsFile=/dev/null", "-o", "StrictHostKeyChecking=no",
                 "-o", "ServerAliveInterval=15", "-MN", "-S", control.replace('%', '%%'), "-p",
                 str(self.networkservice.port), "-l", self.networkservice.username,
                 self.networkservice.address]

        # proxy via the exporter if we have an ifname suffix
        address = self.networkservice.address
        if address.count('%') > 1:
            raise ValueError("Multiple '%' found in '{}'.".format(address))
        if '%' in address:
            address, ifname = address.split('%', 1)
        else:
            ifname = None

        proxy_cmd = proxymanager.get_command(self.networkservice, address, self.networkservice.port, ifname)
        if proxy_cmd:  # only proxy if needed
            # The proxy's stderr is captured to a file so a failure can be reported below.
            args += [
                "-o", "ProxyCommand={} 2>{}".format(' '.join(proxy_cmd), proxy_stderr),
            ]

        env = os.environ.copy()
        if self.networkservice.password:
            # "sshpass -e" takes the password from the SSHPASS environment variable.
            env['SSHPASS'] = self.networkservice.password
        self.process = subprocess.Popen(args, env=env,
                                        stdout=subprocess.PIPE,
                                        stderr=subprocess.STDOUT,
                                        stdin=subprocess.DEVNULL)

        subprocess_timeout = timeout + 5
        try:
            return_value = self.process.wait(timeout=subprocess_timeout)
        except subprocess.TimeoutExpired as e:
            # wait() does not terminate the child on timeout; kill and reap it
            # so no stray ssh outlives the failed activation.
            self.process.kill()
            self.process.wait()
            raise ExecutionError(
                "Subprocess timed out [{}s] while executing {}".
                format(subprocess_timeout, args),
            ) from e

        if return_value != 0:
            stdout = self.process.stdout.readlines()
            for line in stdout:
                self.logger.warning("ssh: %s", line.rstrip().decode(encoding="utf-8", errors="replace"))

            try:
                with open(proxy_stderr, encoding="utf-8", errors="replace") as f:
                    proxy_error = f.read().strip()
            except FileNotFoundError:
                # No ProxyCommand was used (or it never wrote anything).
                proxy_error = ''
            if proxy_error:
                raise ExecutionError(
                    "Failed to connect to {} with {}: error from SSH ProxyCommand: {}".
                    format(self.networkservice.address, " ".join(args), proxy_error),
                    stdout=stdout,
                )

            raise ExecutionError(
                "Failed to connect to {} with {}: return code {}".
                format(self.networkservice.address, " ".join(args), return_value),
                stdout=stdout,
            )

        if not os.path.exists(control):
            raise ExecutionError(
                "no control socket to {}".format(self.networkservice.address)
            )

        self.logger.info('Connected to %s', self.networkservice.address)

        return control

    @Driver.check_active
    @step(args=['cmd'], result=True)
    def run(self, cmd, codec="utf-8", decodeerrors="strict", timeout=None):
        """Run ``cmd`` on the target and return (stdout, stderr, returncode).

        See :meth:`_run` for details; ``timeout`` is passed through to it.
        """
        return self._run(cmd, codec=codec, decodeerrors=decodeerrors, timeout=timeout)

    @staticmethod
    def _decode_lines(data, codec, decodeerrors):
        """Decode ``data`` and split it on '\\n', dropping the empty tail left by a final newline."""
        lines = data.decode(codec, decodeerrors).split('\n')
        if lines[-1] == '':
            lines.pop()
        return lines

    def _run(self, cmd, codec="utf-8", decodeerrors="strict", timeout=None):
        """Execute `cmd` on the target.

        This method runs the specified `cmd` as a command on its target.
        It uses the ssh shell command to run the command and parses the exitcode.

        Args:
            cmd: command string to be run on the target
            codec: encoding used to decode the command output
            decodeerrors: error handling scheme passed to ``bytes.decode``
            timeout: seconds to wait for the command, or None to wait forever

        Returns:
            (stdout, stderr, returncode): stdout and stderr as lists of lines;
            stderr is an empty list when ``stderr_merge`` is set.

        Raises:
            ExecutionError: if the keepalive is gone or ssh cannot be started.
            subprocess.TimeoutExpired: if ``timeout`` expires; the ssh process
                is killed before the exception propagates.
        """
        if not self._check_keepalive():
            raise ExecutionError("Keepalive no longer running")

        complete_cmd = ["ssh", "-x", *self.ssh_prefix,
                        "-p", str(self.networkservice.port), "-l", self.networkservice.username,
                        self.networkservice.address
                        ] + cmd.split(" ")
        self.logger.debug("Sending command: %s", complete_cmd)
        stderr_pipe = subprocess.STDOUT if self.stderr_merge else subprocess.PIPE
        try:
            sub = subprocess.Popen(
                complete_cmd, stdout=subprocess.PIPE, stderr=stderr_pipe
            )
        except OSError as e:
            raise ExecutionError(
                "error executing command: {}".format(complete_cmd)
            ) from e

        try:
            stdout, stderr = sub.communicate(timeout=timeout)
        except subprocess.TimeoutExpired:
            # communicate() leaves the child running on timeout; kill it and
            # collect what is left so the process and its pipes are released.
            sub.kill()
            sub.communicate()
            raise

        stdout = self._decode_lines(stdout, codec, decodeerrors)
        if stderr is None:
            # stderr was merged into stdout
            stderr = []
        else:
            stderr = self._decode_lines(stderr, codec, decodeerrors)
        return (stdout, stderr, sub.returncode)

    @Driver.check_active
    def interact(self, cmd=None):
        """Open an interactive ssh session (with a tty) through the control socket.

        cmd - optional list of arguments to run remotely instead of the default shell

        returns: the exit code of the local ssh process
        """
        assert cmd is None or isinstance(cmd, list)

        if not self._check_keepalive():
            raise ExecutionError("Keepalive no longer running")

        complete_cmd = ["ssh", "-x", *self.ssh_prefix,
                        "-t",
                        self.networkservice.address
                        ]
        if cmd:
            complete_cmd += ["--", *cmd]
        self.logger.debug("Running command: %s", complete_cmd)
        sub = subprocess.Popen(complete_cmd)
        return sub.wait()

    @staticmethod
    def _mark_remote(src, dst):
        """Validate that exactly one of src/dst is remote (":path") and give it a dummy host.

        "_" is only a placeholder host name; the ControlPath option passed by
        the callers selects the existing master connection.
        """
        if src.startswith(':') == dst.startswith(':'):
            raise ValueError("Either source or destination must be remote (start with :)")
        if src.startswith(':'):
            src = '_' + src
        if dst.startswith(':'):
            dst = '_' + dst
        return src, dst

    @Driver.check_active
    @step(args=['src', 'dst'])
    def scp(self, *, src, dst):
        """Copy with scp through the control socket; one side must be ":path".

        Returns the exit code of scp.
        """
        if not self._check_keepalive():
            raise ExecutionError("Keepalive no longer running")

        src, dst = self._mark_remote(src, dst)

        complete_cmd = ["scp",
                        "-F", "none",
                        "-o", "ControlPath={}".format(self._control_path()),
                        src, dst,
                        ]
        self.logger.info("Running command: %s", complete_cmd)
        sub = subprocess.Popen(complete_cmd)
        return sub.wait()

    @Driver.check_active
    @step(args=['src', 'dst', 'extra'])
    def rsync(self, *, src, dst, extra=()):
        """Synchronize with rsync over ssh through the control socket.

        One of ``src``/``dst`` must be remote (":path"); ``extra`` is an
        iterable of additional rsync arguments. Returns the exit code of rsync.
        """
        if not self._check_keepalive():
            raise ExecutionError("Keepalive no longer running")

        src, dst = self._mark_remote(src, dst)

        ssh_cmd = ["ssh",
                   "-F", "none",
                   "-o", "ControlPath={}".format(self._control_path()),
                   ]

        complete_cmd = ["rsync",
                        "-v",
                        "--rsh={}".format(' '.join(ssh_cmd)),
                        "-rlpt",  # --recursive --links --perms --times
                        "--one-file-system",
                        "--progress",
                        *extra,
                        src, dst,
                        ]
        self.logger.info("Running command: %s", complete_cmd)
        sub = subprocess.Popen(complete_cmd)
        return sub.wait()

    @Driver.check_active
    @step(args=['path', 'mountpoint'])
    def sshfs(self, *, path, mountpoint):
        """Mount remote ``path`` on local ``mountpoint`` with sshfs and block until it exits.

        Raises ExecutionError if sshfs exits within the first second.
        """
        if not self._check_keepalive():
            raise ExecutionError("Keepalive no longer running")

        complete_cmd = ["sshfs",
                        "-F", "none",
                        "-f",
                        "-o", "ControlPath={}".format(self._control_path()),
                        ":{}".format(path),
                        mountpoint,
                        ]
        self.logger.debug("Running command: %s", complete_cmd)
        sub = subprocess.Popen(complete_cmd)
        try:
            sub.wait(1)
        except subprocess.TimeoutExpired:  # still running
            self.logger.info("Started SSHFS on %s. Press CTRL-C to stop.", mountpoint)
            sub.wait()
        else:
            # sshfs was started in the foreground (-f); exiting this quickly is a failure.
            raise ExecutionError(
                "error executing command: {}".format(complete_cmd)
            )

    def get_status(self):
        """The SSHDriver is always connected, return 1"""
        return 1

    def _transfer(self, transfer_cmd):
        """Run an scp command line; raise ExecutionError if it cannot start or exits non-zero."""
        try:
            res = subprocess.call(transfer_cmd)
        except OSError as e:
            raise ExecutionError(
                "error executing command: {}".format(transfer_cmd)
            ) from e
        if res != 0:
            raise ExecutionError(
                "error executing command: {}".format(transfer_cmd)
            )

    @Driver.check_active
    @step(args=['filename', 'remotepath'])
    def put(self, filename, remotepath=''):
        """Recursively copy local ``filename`` to ``remotepath`` on the target via scp."""
        transfer_cmd = [
            "scp",
            *self.ssh_prefix,
            "-P", str(self.networkservice.port),
            "-r",
            filename,
            "{user}@{host}:{remotepath}".format(
                user=self.networkservice.username,
                host=self.networkservice.address,
                remotepath=remotepath)
        ]
        self._transfer(transfer_cmd)

    @Driver.check_active
    @step(args=['filename', 'destination'])
    def get(self, filename, destination="."):
        """Recursively copy ``filename`` from the target to local ``destination`` via scp."""
        transfer_cmd = [
            "scp",
            *self.ssh_prefix,
            "-P", str(self.networkservice.port),
            "-r",
            "{user}@{host}:{filename}".format(
                user=self.networkservice.username,
                host=self.networkservice.address,
                filename=filename),
            destination
        ]
        self._transfer(transfer_cmd)

    def _cleanup_own_master(self):
        """Exit the controlmaster and delete the tmpdir

        The temporary directory is removed even if running ssh raises.
        """
        complete_cmd = [
            "ssh", "-x",
            "-o", "ControlPath={}".format(self._control_path()),
            "-O", "exit",
            "-p", str(self.networkservice.port),
            "-l", self.networkservice.username,
            self.networkservice.address,
        ]
        try:
            res = subprocess.call(
                complete_cmd,
                stdin=subprocess.DEVNULL,
                stdout=subprocess.DEVNULL,
                stderr=subprocess.DEVNULL,
            )
            if res != 0:
                self.logger.info("Socket already closed")
        finally:
            shutil.rmtree(self.tmpdir)

    def _start_keepalive(self):
        """Starts a keepalive connection via the own or external master.

        Runs ``cat`` on the target with stdin kept open, so the session stays
        alive until _stop_keepalive() closes stdin.
        """
        args = ["ssh", *self.ssh_prefix, self.networkservice.address, "cat"]

        assert self._keepalive is None
        self._keepalive = subprocess.Popen(
            args,
            stdin=subprocess.PIPE,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
        )

        self.logger.debug('Started keepalive for %s', self.networkservice.address)

    def _check_keepalive(self):
        """Return True while a keepalive process exists and has not exited."""
        return self._keepalive is not None and self._keepalive.poll() is None

    def _stop_keepalive(self):
        """Close the keepalive's stdin and wait for it; kill it after 60s without exit."""
        assert self._keepalive is not None

        self.logger.debug('Stopping keepalive for %s', self.networkservice.address)

        try:
            self._keepalive.communicate(timeout=60)
        except subprocess.TimeoutExpired:
            self._keepalive.kill()

        try:
            self._keepalive.wait(timeout=60)
        finally:
            self._keepalive = None