# Source code for labgrid.driver.sshdriver

"""The SSHDriver uses SSH as a transport to implement CommandProtocol and FileTransferProtocol"""
import contextlib
import logging
import os
import re
import stat
import shlex
import shutil
import subprocess
import tempfile
import time

import attr

from ..factory import target_factory
from ..protocol import CommandProtocol, FileTransferProtocol
from .commandmixin import CommandMixin
from .common import Driver
from ..step import step
from .exception import ExecutionError
from ..util.helper import get_free_port
from ..util.proxy import proxymanager
from ..util.timeout import Timeout
from ..util.ssh import get_ssh_connect_timeout


@target_factory.reg_driver
@attr.s(eq=False)
class SSHDriver(CommandMixin, Driver, CommandProtocol, FileTransferProtocol):
    """SSHDriver - Driver to execute commands via SSH

    On activation the driver starts its own SSH ControlMaster connection to the
    bound NetworkService (socket kept in a private temporary directory) and a
    keepalive session on top of it. ssh, scp, rsync and sshfs invocations are
    afterwards pointed at the master's control socket via ControlPath.

    Attributes:
        keyfile (str): private key passed to ssh via ``-i``; resolved through the
            environment config when the target has an environment
        stderr_merge (bool): merge stderr into stdout for commands executed by run()
        connection_timeout (float): total time in seconds during which starting
            the master connection is retried (ints are converted to float)
        explicit_sftp_mode (bool): request scp's SFTP mode (``-s``) where
            _scp_supports_explicit_sftp_mode() reports that it is needed
    """
    bindings = {"networkservice": "NetworkService", }

    priorities = {CommandProtocol: 10, FileTransferProtocol: 10}
    keyfile = attr.ib(default="", validator=attr.validators.instance_of(str))
    stderr_merge = attr.ib(default=False, validator=attr.validators.instance_of(bool))
    # converter: accept integral timeouts (e.g. from YAML) instead of rejecting them
    connection_timeout = attr.ib(
        default=float(get_ssh_connect_timeout()),
        converter=float,
        validator=attr.validators.instance_of(float),
    )
    explicit_sftp_mode = attr.ib(default=False, validator=attr.validators.instance_of(bool))

    def __attrs_post_init__(self):
        super().__attrs_post_init__()
        self.logger = logging.getLogger(f"{self}({self.target})")
        self._keepalive = None

    def on_activate(self):
        """Build the common ssh options, start the ControlMaster and the keepalive."""
        self.ssh_prefix = ["-o", "LogLevel=ERROR"]
        if self.keyfile:
            keyfile_path = self.keyfile
            if self.target.env:
                keyfile_path = self.target.env.config.resolve_path(self.keyfile)
            self.ssh_prefix += ["-i", keyfile_path]
        if not self.networkservice.password:
            self.ssh_prefix += ["-o", "PasswordAuthentication=no"]

        # The master is started with the options collected so far (it is not
        # passed "-F none"); every later client additionally gets "-F none" and
        # the ControlPath of the master.
        self.control = self._start_own_master()
        self.ssh_prefix += ["-F", "none"]
        if self.control:
            # '%' introduces a token in ControlPath, so escape literal ones
            self.ssh_prefix += ["-o", f"ControlPath={self.control.replace('%', '%%')}"]

        self._keepalive = None
        try:
            self._start_keepalive()
        except Exception:
            # don't leave the freshly started master and its tmpdir behind
            self._cleanup_own_master()
            raise

    def on_deactivate(self):
        """Stop the keepalive session, then shut down the ControlMaster."""
        try:
            self._stop_keepalive()
        finally:
            self._cleanup_own_master()

    @property
    def skip_deactivate_on_export(self):
        # We need to keep the connection to the target open.
        return True

    def _start_own_master(self):
        """Start a ControlMaster connection in a temporary directory.

        Failures such as "connection refused" while the target is still booting
        are retried every 0.5s until connection_timeout has elapsed.

        Returns:
            str: path of the control socket

        Raises:
            ExecutionError: the last attempt failed after the timeout expired
            Exception: no time was left for another attempt
        """
        timeout = Timeout(self.connection_timeout)

        connect_timeout = round(timeout.remaining)
        while True:
            # a rounded remaining time of zero or less leaves nothing to try with
            if connect_timeout <= 0:
                raise Exception("Timeout while waiting for ssh connection")
            try:
                return self._start_own_master_once(connect_timeout)
            except ExecutionError:
                if timeout.expired:
                    raise
                time.sleep(0.5)
                connect_timeout = round(timeout.remaining)

    def _start_own_master_once(self, timeout):
        """Make a single attempt to start the ControlMaster.

        A fresh temporary directory is created for the control socket; it is
        removed again if the attempt fails so retries do not accumulate them.

        Args:
            timeout (int): ssh ConnectTimeout in seconds

        Returns:
            str: path of the control socket

        Raises:
            ExecutionError: ssh failed, timed out or no control socket appeared
            ValueError: the NetworkService address contains more than one '%'
        """
        self.tmpdir = tempfile.mkdtemp(prefix='lg-ssh-')
        control = os.path.join(
            self.tmpdir, f'control-{self.networkservice.address}'
        )
        try:
            self._spawn_master(control, timeout)
            if not os.path.exists(control):
                raise ExecutionError(
                    f"no control socket to {self.networkservice.address}"
                )
        except BaseException:
            shutil.rmtree(self.tmpdir, ignore_errors=True)
            raise

        self.logger.info('Connected to %s', self.networkservice.address)
        return control

    def _spawn_master(self, control, timeout):
        """Run ``ssh -f -MN`` for the master and wait for the local process to exit.

        Args:
            control (str): path where the master should create its socket
            timeout (int): ssh ConnectTimeout in seconds; the local process is
                allowed 5 additional seconds before it is killed

        Raises:
            ExecutionError: ssh returned non-zero or did not finish in time
            ValueError: the address contains more than one '%'
        """
        args = ["ssh", "-f", *self.ssh_prefix, "-x", "-o", f"ConnectTimeout={timeout}",
                "-o", "ControlPersist=300", "-o",
                "UserKnownHostsFile=/dev/null", "-o", "StrictHostKeyChecking=no",
                "-o", "ServerAliveInterval=15", "-MN", "-S", control.replace('%', '%%'), "-p",
                str(self.networkservice.port), "-l", self.networkservice.username,
                self.networkservice.address]

        # proxy via the exporter if we have an ifname suffix
        address = self.networkservice.address
        if address.count('%') > 1:
            raise ValueError(f"Multiple '%' found in '{address}'.")
        if '%' in address:
            address, ifname = address.split('%', 1)
        else:
            ifname = None

        proxy_cmd = proxymanager.get_command(self.networkservice, address,
                                             self.networkservice.port, ifname)
        proxy_stderr = os.path.join(self.tmpdir, 'proxy-stderr')
        if proxy_cmd:  # only proxy if needed
            # the ProxyCommand line is interpreted by a shell: quote the path
            args += [
                "-o",
                f"ProxyCommand={' '.join(proxy_cmd)} 2>{shlex.quote(proxy_stderr)}"
            ]

        env = os.environ.copy()
        pass_file = ''
        try:
            if self.networkservice.password:
                fd, pass_file = tempfile.mkstemp()
                os.fchmod(fd, stat.S_IRWXU)
                # with openssh>=8.4 SSH_ASKPASS_REQUIRE can be used to force SSH_ASKPASS;
                # openssh<8.4 requires the DISPLAY var and a detached process
                # (start_new_session=True below)
                env.update({'SSH_ASKPASS': pass_file, 'DISPLAY': '', 'SSH_ASKPASS_REQUIRE': 'force'})
                with open(fd, 'w') as f:
                    # printf prints the password verbatim; sh's echo may interpret
                    # backslash escapes or a leading "-n"
                    f.write("#!/bin/sh\nprintf '%s\\n' " + shlex.quote(self.networkservice.password))

            self.process = subprocess.Popen(args, env=env,
                                            stdout=subprocess.PIPE,
                                            stderr=subprocess.STDOUT,
                                            stdin=subprocess.DEVNULL,
                                            start_new_session=True)

            subprocess_timeout = timeout + 5
            try:
                return_value = self.process.wait(timeout=subprocess_timeout)
                if return_value != 0:
                    stdout, _ = self.process.communicate(timeout=subprocess_timeout)
            except subprocess.TimeoutExpired as e:
                # wait()/communicate() leave the child running on timeout
                self.process.kill()
                self.process.wait()
                raise ExecutionError(
                    f"Subprocess timed out [{subprocess_timeout}s] while executing {args}",
                ) from e

            if return_value != 0:
                stdout = stdout.split(b"\n")
                for line in stdout:
                    self.logger.warning("ssh: %s", line.rstrip().decode(encoding="utf-8", errors="replace"))

                try:
                    with open(proxy_stderr) as proxy_err_fd:
                        proxy_error = proxy_err_fd.read().strip()
                except FileNotFoundError:
                    proxy_error = ''
                if proxy_error:
                    raise ExecutionError(
                        f"Failed to connect to {self.networkservice.address} with {' '.join(args)}: error from SSH ProxyCommand: {proxy_error}",  # pylint: disable=line-too-long
                        stdout=stdout,
                    )

                raise ExecutionError(
                    f"Failed to connect to {self.networkservice.address} with {' '.join(args)}: return code {return_value}",  # pylint: disable=line-too-long
                    stdout=stdout,
                )
        finally:
            # never leave the plaintext password script behind
            if pass_file and os.path.exists(pass_file):
                os.remove(pass_file)

    @Driver.check_active
    @step(args=['cmd'], result=True)
    def run(self, cmd, codec="utf-8", decodeerrors="strict", timeout=None):
        return self._run(cmd, codec=codec, decodeerrors=decodeerrors, timeout=timeout)

    def _run(self, cmd, codec="utf-8", decodeerrors="strict", timeout=None):
        """Execute `cmd` on the target.

        This method runs the specified `cmd` as a command on its target.
        It uses the ssh shell command to run the command and parses the exitcode.
        cmd - command to be run on the target

        returns:
        (stdout, stderr, returncode): stdout/stderr as lists of decoded lines
        without a trailing empty line; stderr is empty if stderr_merge is set

        raises:
        ExecutionError: keepalive is not running or ssh could not be started
        subprocess.TimeoutExpired: the command did not finish within `timeout`
        """
        if not self._check_keepalive():
            raise ExecutionError("Keepalive no longer running")

        complete_cmd = ["ssh", "-x", *self.ssh_prefix,
                        "-p", str(self.networkservice.port), "-l", self.networkservice.username,
                        self.networkservice.address
                        ] + cmd.split(" ")
        self.logger.debug("Sending command: %s", complete_cmd)
        if self.stderr_merge:
            stderr_pipe = subprocess.STDOUT
        else:
            stderr_pipe = subprocess.PIPE
        try:
            sub = subprocess.Popen(
                complete_cmd, stdin=subprocess.DEVNULL, stdout=subprocess.PIPE, stderr=stderr_pipe
            )
        except (OSError, ValueError) as e:
            raise ExecutionError(
                f"error executing command: {complete_cmd}"
            ) from e

        try:
            stdout, stderr = sub.communicate(timeout=timeout)
        except subprocess.TimeoutExpired:
            # communicate() does not kill the child on timeout; do not leak it
            sub.kill()
            sub.wait()
            raise

        stdout = stdout.decode(codec, decodeerrors).split('\n')
        if stdout[-1] == '':
            stdout.pop()
        if stderr is None:
            stderr = []
        else:
            stderr = stderr.decode(codec, decodeerrors).split('\n')
            # only drop the empty element after a trailing newline, like stdout
            if stderr[-1] == '':
                stderr.pop()
        return (stdout, stderr, sub.returncode)

    def interact(self, cmd=None):
        """Run an interactive ssh session (pty via -t) on the target.

        Args:
            cmd (list or None): command to run instead of the login shell

        Returns:
            int: exit code of the local ssh process
        """
        assert cmd is None or isinstance(cmd, list)

        if not self._check_keepalive():
            raise ExecutionError("Keepalive no longer running")

        complete_cmd = ["ssh", "-x", *self.ssh_prefix,
                        "-t",
                        self.networkservice.address
                        ]
        if cmd:
            complete_cmd += ["--", *cmd]

        self.logger.debug("Running command: %s", complete_cmd)
        sub = subprocess.Popen(
            complete_cmd,
        )
        return sub.wait()

    @contextlib.contextmanager
    def _forward(self, forward):
        """Add the port forwarding spec `forward` to the master for the context's duration."""
        cmd = ["ssh", *self.ssh_prefix,
               "-O", "forward", forward,
               self.networkservice.address
               ]
        self.logger.debug("Running command: %s", cmd)
        subprocess.run(cmd, check=True)
        try:
            yield
        finally:
            cmd = ["ssh", *self.ssh_prefix,
                   "-O", "cancel", forward,
                   self.networkservice.address
                   ]
            self.logger.debug("Running command: %s", cmd)
            # Master socket may have been cleaned up already, so don't bother
            # the user with an error message
            subprocess.run(cmd, stderr=subprocess.DEVNULL)

    @Driver.check_active
    @contextlib.contextmanager
    def forward_local_port(self, remoteport, localport=None):
        """Forward a local port to a remote port on the target

        A context manager that keeps a local port forwarded to a remote port as
        long as the context remains valid. A connection can be made to the
        returned port on localhost and it will be forwarded to the remote port
        on the target device

        usage:
            with ssh.forward_local_port(8080) as localport:
                # Use localhost:localport here to connect to port 8080 on the
                # target

        returns:
        localport
        """
        if not self._check_keepalive():
            raise ExecutionError("Keepalive no longer running")

        if localport is None:
            localport = get_free_port()

        forward = f"-L{localport:d}:localhost:{remoteport:d}"
        with self._forward(forward):
            yield localport

    @Driver.check_active
    @contextlib.contextmanager
    def forward_remote_port(self, remoteport, localport):
        """Forward a remote port on the target to a local port

        A context manager that keeps a remote port forwarded to a local port as
        long as the context remains valid. A connection can be made to the
        remote on the target device will be forwarded to the returned local
        port on localhost

        usage:
            with ssh.forward_remote_port(8080, 8081) as localport:
                # Connections to port 8080 on the target will be redirected to
                # localhost:8081
        """
        if not self._check_keepalive():
            raise ExecutionError("Keepalive no longer running")

        forward = f"-R{remoteport:d}:localhost:{localport:d}"
        with self._forward(forward):
            yield

    @Driver.check_active
    @step(args=['src', 'dst'])
    def scp(self, *, src, dst):
        """Copy with scp through the master connection.

        Exactly one of `src`/`dst` must be remote, marked by a leading ':'. The
        remote side gets the placeholder host name '_'; the connection is
        selected via ControlPath.

        Returns:
            int: exit code of scp
        """
        if not self._check_keepalive():
            raise ExecutionError("Keepalive no longer running")

        if src.startswith(':') == dst.startswith(':'):
            raise ValueError("Either source or destination must be remote (start with :)")
        if src.startswith(':'):
            src = '_' + src
        if dst.startswith(':'):
            dst = '_' + dst

        complete_cmd = ["scp",
                        "-F", "none",
                        "-o", f"ControlPath={self.control.replace('%', '%%')}",
                        src, dst,
                        ]

        if self.explicit_sftp_mode and self._scp_supports_explicit_sftp_mode():
            complete_cmd.insert(1, "-s")

        self.logger.info("Running command: %s", complete_cmd)
        sub = subprocess.Popen(
            complete_cmd,
        )
        return sub.wait()

    @Driver.check_active
    @step(args=['src', 'dst', 'extra'])
    def rsync(self, *, src, dst, extra=()):
        """Synchronize with rsync, using ssh over the master connection as transport.

        Exactly one of `src`/`dst` must be remote, marked by a leading ':'.
        `extra` is a sequence of additional rsync arguments.

        Returns:
            int: exit code of rsync
        """
        if not self._check_keepalive():
            raise ExecutionError("Keepalive no longer running")

        if src.startswith(':') == dst.startswith(':'):
            raise ValueError("Either source or destination must be remote (start with :)")
        if src.startswith(':'):
            src = '_' + src
        if dst.startswith(':'):
            dst = '_' + dst

        ssh_cmd = ["ssh",
                   "-F", "none",
                   "-o", f"ControlPath={self.control.replace('%', '%%')}",
                   ]

        complete_cmd = ["rsync",
                        "-v",
                        f"--rsh={' '.join(ssh_cmd)}",
                        "-rlpt",  # --recursive --links --perms --times
                        "--one-file-system",
                        "--progress",
                        *extra,
                        src, dst,
                        ]
        self.logger.info("Running command: %s", complete_cmd)
        sub = subprocess.Popen(
            complete_cmd,
        )
        return sub.wait()

    @Driver.check_active
    @step(args=['path', 'mountpoint'])
    def sshfs(self, *, path, mountpoint):
        """Mount remote `path` at local `mountpoint` via sshfs (foreground, -f).

        Blocks until sshfs exits.

        Raises:
            ExecutionError: sshfs exited within the first second
        """
        if not self._check_keepalive():
            raise ExecutionError("Keepalive no longer running")

        complete_cmd = ["sshfs",
                        "-F", "none",
                        "-f",
                        "-o", f"ControlPath={self.control.replace('%', '%%')}",
                        f":{path}",
                        mountpoint,
                        ]
        self.logger.debug("Running command: %s", complete_cmd)
        sub = subprocess.Popen(
            complete_cmd,
        )
        try:
            sub.wait(1)
        except subprocess.TimeoutExpired:
            # still running
            self.logger.info("Started SSHFS on %s. Press CTRL-C to stop.", mountpoint)
            sub.wait()
        else:
            # an sshfs running in the foreground must not exit this early
            raise ExecutionError(
                f"error executing command: {complete_cmd}"
            )

    def get_status(self):
        """The SSHDriver is always connected, return 1"""
        return 1

    def _scp_supports_explicit_sftp_mode(self):
        """Return whether scp needs ``-s`` to use the SFTP protocol.

        Returns:
            bool: True for OpenSSH 8.6 - 8.x, False for OpenSSH >= 9.0 (SFTP is
            already the default there)

        Raises:
            Exception: the local ssh version is unknown or older than 8.6
        """
        version = subprocess.run(["ssh", "-V"], capture_output=True, text=True)
        match = re.match(r"^OpenSSH_(\d+)\.(\d+)", version.stderr)
        if match is None:
            raise Exception(f"Could not determine OpenSSH version from {version.stderr.strip()!r}")
        major, minor = map(int, match.groups())

        # OpenSSH >= 8.6 supports explicitly using the SFTP protocol via -s
        if major == 8 and minor >= 6:
            return True

        # OpenSSH >= 9.0 default to the SFTP protocol
        if major >= 9:
            return False

        raise Exception(f"OpenSSH version {major}.{minor} does not support explicit SFTP mode")

    def _scp_transfer(self, src, dst):
        """Recursively copy `src` to `dst` with scp using the common ssh options.

        Raises:
            ExecutionError: scp could not be started or exited non-zero
        """
        transfer_cmd = [
            "scp",
            *self.ssh_prefix,
            "-P", str(self.networkservice.port),
            "-r",
            src,
            dst,
        ]

        if self.explicit_sftp_mode and self._scp_supports_explicit_sftp_mode():
            transfer_cmd.insert(1, "-s")

        try:
            returncode = subprocess.call(transfer_cmd)
        except (OSError, ValueError) as e:
            raise ExecutionError(
                f"error executing command: {transfer_cmd}"
            ) from e
        if returncode != 0:
            raise ExecutionError(
                f"error executing command: {transfer_cmd}"
            )

    @Driver.check_active
    @step(args=['filename', 'remotepath'])
    def put(self, filename, remotepath=''):
        """Copy local `filename` to `remotepath` on the target (recursively)."""
        self._scp_transfer(
            filename,
            f"{self.networkservice.username}@{self.networkservice.address}:{remotepath}",
        )

    @Driver.check_active
    @step(args=['filename', 'destination'])
    def get(self, filename, destination="."):
        """Copy `filename` from the target to local `destination` (recursively)."""
        self._scp_transfer(
            f"{self.networkservice.username}@{self.networkservice.address}:{filename}",
            destination,
        )

    def _cleanup_own_master(self):
        """Exit the controlmaster and delete the tmpdir"""
        # build argv directly: splitting a formatted string on ' ' would break
        # on paths or user names containing spaces
        complete_cmd = [
            "ssh", "-x",
            "-o", f"ControlPath={self.control.replace('%', '%%')}",
            "-O", "exit",
            "-p", str(self.networkservice.port),
            "-l", self.networkservice.username,
            self.networkservice.address,
        ]
        res = subprocess.call(
            complete_cmd,
            stdin=subprocess.DEVNULL,
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL
        )
        if res != 0:
            self.logger.info("Socket already closed")
        self.process.communicate()
        shutil.rmtree(self.tmpdir, ignore_errors=True)

    def _start_keepalive(self):
        """Start a keepalive session ("cat" on the target) via the own or external master.

        The remote "cat" reads the stdin pipe; _stop_keepalive() ends it by
        closing that pipe.
        """
        args = ["ssh", *self.ssh_prefix, self.networkservice.address, "cat"]

        assert self._keepalive is None
        self._keepalive = subprocess.Popen(
            args,
            stdin=subprocess.PIPE,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
        )

        self.logger.debug('Started keepalive for %s', self.networkservice.address)

    def _check_keepalive(self):
        """Return True while the keepalive process exists and is still running."""
        return self._keepalive is not None and self._keepalive.poll() is None

    def _stop_keepalive(self):
        """Stop the keepalive session, killing it if it does not exit within 60s."""
        assert self._keepalive is not None

        self.logger.debug('Stopping keepalive for %s', self.networkservice.address)

        try:
            # communicate() closes stdin, which lets the remote "cat" finish
            self._keepalive.communicate(timeout=60)
        except subprocess.TimeoutExpired:
            self._keepalive.kill()

        try:
            self._keepalive.wait(timeout=60)
        finally:
            self._keepalive = None