#!/usr/bin/env python3 # -*- coding: utf-8 -*- # # Project ___| | | | _ \| | # / __| | | | |_) | | # | (__| |_| | _ <| |___ # \___|\___/|_| \_\_____| # # Copyright (C) Daniel Stenberg, , et al. # # This software is licensed as described in the file COPYING, which # you should have received as part of this distribution. The terms # are also available at https://curl.se/docs/copyright.html. # # You may opt to use, copy, modify, merge, publish, distribute and/or sell # copies of the Software, and permit persons to whom the Software is # furnished to do so, under the terms of the COPYING file. # # This software is distributed on an "AS IS" basis, WITHOUT WARRANTY OF ANY # KIND, either express or implied. # # SPDX-License-Identifier: curl # """Server for testing SMB.""" from __future__ import (absolute_import, division, print_function, unicode_literals) import argparse import logging import os import signal import sys import tempfile import threading # Import our curl test data helper from util import ClosingFileHandler, TestData if sys.version_info.major >= 3: import configparser else: import ConfigParser as configparser # impacket needs to be installed in the Python environment try: import impacket # noqa: F401 except ImportError: sys.stderr.write( 'Warning: Python package impacket is required for smb testing; ' 'use pip or your package manager to install it\n') sys.exit(1) from impacket import smb as imp_smb from impacket import smbserver as imp_smbserver from impacket.nt_errors import (STATUS_ACCESS_DENIED, STATUS_NO_SUCH_FILE, STATUS_SUCCESS) log = logging.getLogger(__name__) SERVER_MAGIC = "SERVER_MAGIC" TESTS_MAGIC = "TESTS_MAGIC" VERIFIED_REQ = "verifiedserver" VERIFIED_RSP = "WE ROOLZ: {pid}\n" class ShutdownHandler(threading.Thread): """ Cleanly shut down the SMB server. This can only be done from another thread while the server is in serve_forever(), so a thread is spawned here that waits for a shutdown signal before doing its thing. Use in a with statement around the serve_forever() call. """ def __init__(self, server): super(ShutdownHandler, self).__init__() self.server = server self.shutdown_event = threading.Event() def __enter__(self): self.start() signal.signal(signal.SIGINT, self._sighandler) signal.signal(signal.SIGTERM, self._sighandler) def __exit__(self, *_): # Call for shutdown just in case it wasn't done already self.shutdown_event.set() # Wait for thread, and therefore also the server, to finish self.join() # Uninstall our signal handlers signal.signal(signal.SIGINT, signal.SIG_DFL) signal.signal(signal.SIGTERM, signal.SIG_DFL) # Delete any temporary files created by the server during its run log.info("Deleting %d temporary file(s)", len(self.server.tmpfiles)) for f in self.server.tmpfiles: os.unlink(f) def _sighandler(self, _signum, _frame): # Wake up the cleanup task self.shutdown_event.set() def run(self): # Wait for shutdown signal self.shutdown_event.wait() # Notify the server to shut down self.server.shutdown() def smbserver(options): """Start up a TCP SMB server that serves forever.""" if options.pidfile: pid = os.getpid() # see tests/server/util.c function write_pidfile if os.name == "nt": pid += 65536 with open(options.pidfile, "w") as f: f.write(str(pid)) # Here we write a mini config for the server smb_config = configparser.ConfigParser() smb_config.add_section("global") smb_config.set("global", "server_name", "SERVICE") smb_config.set("global", "server_os", "UNIX") smb_config.set("global", "server_domain", "WORKGROUP") smb_config.set("global", "log_file", "None") smb_config.set("global", "credentials_file", "") # We need a share which allows us to test that the server is running smb_config.add_section("SERVER") smb_config.set("SERVER", "comment", "server function") smb_config.set("SERVER", "read only", "yes") smb_config.set("SERVER", "share type", "0") smb_config.set("SERVER", "path", SERVER_MAGIC) # Have a share for tests. These files will be autogenerated from the # test input. smb_config.add_section("TESTS") smb_config.set("TESTS", "comment", "tests") smb_config.set("TESTS", "read only", "yes") smb_config.set("TESTS", "share type", "0") smb_config.set("TESTS", "path", TESTS_MAGIC) if not options.srcdir or not os.path.isdir(options.srcdir): raise ScriptError("--srcdir is mandatory") test_data_dir = os.path.join(options.srcdir, "data") smb_server = TestSmbServer((options.host, options.port), config_parser=smb_config, test_data_directory=test_data_dir) log.info("[SMB] setting up SMB server on port %s", options.port) smb_server.processConfigFile() # Start a thread that cleanly shuts down the server on a signal with ShutdownHandler(smb_server): # This will block until smb_server.shutdown() is called smb_server.serve_forever() return 0 class TestSmbServer(imp_smbserver.SMBSERVER): """ Test server for SMB which subclasses the impacket SMBSERVER and provides test functionality. """ def __init__(self, address, config_parser=None, test_data_directory=None): imp_smbserver.SMBSERVER.__init__(self, address, config_parser=config_parser) self.tmpfiles = [] # Set up a test data object so we can get test data later. self.ctd = TestData(test_data_directory) # Override smbComNtCreateAndX so we can pretend to have files which # don't exist. self.hookSmbCommand(imp_smb.SMB.SMB_COM_NT_CREATE_ANDX, self.create_and_x) def create_and_x(self, conn_id, smb_server, smb_command, recv_packet): """ Our version of smbComNtCreateAndX looks for special test files and fools the rest of the framework into opening them as if they were normal files. """ conn_data = smb_server.getConnectionData(conn_id) # Wrap processing in a try block which allows us to throw SmbError # to control the flow. try: ncax_parms = imp_smb.SMBNtCreateAndX_Parameters( smb_command["Parameters"]) path = self.get_share_path(conn_data, ncax_parms["RootFid"], recv_packet["Tid"]) log.info("[SMB] Requested share path: %s", path) disposition = ncax_parms["Disposition"] log.debug("[SMB] Requested disposition: %s", disposition) # Currently we only support reading files. if disposition != imp_smb.FILE_OPEN: raise SmbError(STATUS_ACCESS_DENIED, "Only support reading files") # Check to see if the path we were given is actually a # magic path which needs generating on the fly. if path not in [SERVER_MAGIC, TESTS_MAGIC]: # Pass the command onto the original handler. return imp_smbserver.SMBCommands.smbComNtCreateAndX(conn_id, smb_server, smb_command, recv_packet) flags2 = recv_packet["Flags2"] ncax_data = imp_smb.SMBNtCreateAndX_Data(flags=flags2, data=smb_command[ "Data"]) requested_file = imp_smbserver.decodeSMBString( flags2, ncax_data["FileName"]) log.debug("[SMB] User requested file '%s'", requested_file) if path == SERVER_MAGIC: fid, full_path = self.get_server_path(requested_file) else: assert path == TESTS_MAGIC fid, full_path = self.get_test_path(requested_file) self.tmpfiles.append(full_path) resp_parms = imp_smb.SMBNtCreateAndXResponse_Parameters() resp_data = "" # Simple way to generate a fid if len(conn_data["OpenedFiles"]) == 0: fakefid = 1 else: fakefid = conn_data["OpenedFiles"].keys()[-1] + 1 resp_parms["Fid"] = fakefid resp_parms["CreateAction"] = disposition if os.path.isdir(path): resp_parms[ "FileAttributes"] = imp_smb.SMB_FILE_ATTRIBUTE_DIRECTORY resp_parms["IsDirectory"] = 1 else: resp_parms["IsDirectory"] = 0 resp_parms["FileAttributes"] = ncax_parms["FileAttributes"] # Get this file's information resp_info, error_code = imp_smbserver.queryPathInformation( os.path.dirname(full_path), os.path.basename(full_path), level=imp_smb.SMB_QUERY_FILE_ALL_INFO) if error_code != STATUS_SUCCESS: raise SmbError(error_code, "Failed to query path info") resp_parms["CreateTime"] = resp_info["CreationTime"] resp_parms["LastAccessTime"] = resp_info[ "LastAccessTime"] resp_parms["LastWriteTime"] = resp_info["LastWriteTime"] resp_parms["LastChangeTime"] = resp_info[ "LastChangeTime"] resp_parms["FileAttributes"] = resp_info[ "ExtFileAttributes"] resp_parms["AllocationSize"] = resp_info[ "AllocationSize"] resp_parms["EndOfFile"] = resp_info["EndOfFile"] # Let's store the fid for the connection # smbServer.log("Create file %s, mode:0x%x" % (pathName, mode)) conn_data["OpenedFiles"][fakefid] = {} conn_data["OpenedFiles"][fakefid]["FileHandle"] = fid conn_data["OpenedFiles"][fakefid]["FileName"] = path conn_data["OpenedFiles"][fakefid]["DeleteOnClose"] = False except SmbError as s: log.debug("[SMB] SmbError hit: %s", s) error_code = s.error_code resp_parms = "" resp_data = "" resp_cmd = imp_smb.SMBCommand(imp_smb.SMB.SMB_COM_NT_CREATE_ANDX) resp_cmd["Parameters"] = resp_parms resp_cmd["Data"] = resp_data smb_server.setConnectionData(conn_id, conn_data) return [resp_cmd], None, error_code def get_share_path(self, conn_data, root_fid, tid): conn_shares = conn_data["ConnectedShares"] if tid in conn_shares: if root_fid > 0: # If we have a rootFid, the path is relative to that fid path = conn_data["OpenedFiles"][root_fid]["FileName"] log.debug("RootFid present %s!" % path) else: if "path" in conn_shares[tid]: path = conn_shares[tid]["path"] else: raise SmbError(STATUS_ACCESS_DENIED, "Connection share had no path") else: raise SmbError(imp_smbserver.STATUS_SMB_BAD_TID, "TID was invalid") return path def get_server_path(self, requested_filename): log.debug("[SMB] Get server path '%s'", requested_filename) if requested_filename not in [VERIFIED_REQ]: raise SmbError(STATUS_NO_SUCH_FILE, "Couldn't find the file") fid, filename = tempfile.mkstemp() log.debug("[SMB] Created %s (%d) for storing '%s'", filename, fid, requested_filename) contents = "" if requested_filename == VERIFIED_REQ: log.debug("[SMB] Verifying server is alive") pid = os.getpid() # see tests/server/util.c function write_pidfile if os.name == "nt": pid += 65536 contents = VERIFIED_RSP.format(pid=pid).encode('utf-8') self.write_to_fid(fid, contents) return fid, filename def write_to_fid(self, fid, contents): # Write the contents to file descriptor os.write(fid, contents) os.fsync(fid) # Rewind the file to the beginning so a read gets us the contents os.lseek(fid, 0, os.SEEK_SET) def get_test_path(self, requested_filename): log.info("[SMB] Get reply data from 'test%s'", requested_filename) fid, filename = tempfile.mkstemp() log.debug("[SMB] Created %s (%d) for storing test '%s'", filename, fid, requested_filename) try: contents = self.ctd.get_test_data(requested_filename).encode('utf-8') self.write_to_fid(fid, contents) return fid, filename except Exception: log.exception("Failed to make test file") raise SmbError(STATUS_NO_SUCH_FILE, "Failed to make test file") class SmbError(Exception): def __init__(self, error_code, error_message): super(SmbError, self).__init__(error_message) self.error_code = error_code class ScriptRC(object): """Enum for script return codes.""" SUCCESS = 0 FAILURE = 1 EXCEPTION = 2 class ScriptError(Exception): pass def get_options(): parser = argparse.ArgumentParser() parser.add_argument("--port", action="store", default=9017, type=int, help="port to listen on") parser.add_argument("--host", action="store", default="127.0.0.1", help="host to listen on") parser.add_argument("--verbose", action="store", type=int, default=0, help="verbose output") parser.add_argument("--pidfile", action="store", help="file name for the PID") parser.add_argument("--logfile", action="store", help="file name for the log") parser.add_argument("--srcdir", action="store", help="test directory") parser.add_argument("--id", action="store", help="server ID") parser.add_argument("--ipv4", action="store_true", default=0, help="IPv4 flag") return parser.parse_args() def setup_logging(options): """Set up logging from the command line options.""" root_logger = logging.getLogger() add_stdout = False formatter = logging.Formatter("%(asctime)s %(levelname)-5.5s %(message)s") # Write out to a logfile if options.logfile: handler = ClosingFileHandler(options.logfile) handler.setFormatter(formatter) handler.setLevel(logging.DEBUG) root_logger.addHandler(handler) else: # The logfile wasn't specified. Add a stdout logger. add_stdout = True if options.verbose: # Add a stdout logger as well in verbose mode root_logger.setLevel(logging.DEBUG) add_stdout = True else: root_logger.setLevel(logging.WARNING) if add_stdout: stdout_handler = logging.StreamHandler(sys.stdout) stdout_handler.setFormatter(formatter) stdout_handler.setLevel(logging.DEBUG) root_logger.addHandler(stdout_handler) if __name__ == '__main__': # Get the options from the user. options = get_options() # Setup logging using the user options setup_logging(options) # Run main script. try: rc = smbserver(options) except Exception: log.exception('Error in SMB server') rc = ScriptRC.EXCEPTION if options.pidfile and os.path.isfile(options.pidfile): os.unlink(options.pidfile) log.info("[SMB] Returning %d", rc) sys.exit(rc)