1#!/usr/bin/env python3 2# -*- coding: utf-8 -*- 3# 4# Project ___| | | | _ \| | 5# / __| | | | |_) | | 6# | (__| |_| | _ <| |___ 7# \___|\___/|_| \_\_____| 8# 9# Copyright (C) Daniel Stenberg, <daniel@haxx.se>, et al. 10# 11# This software is licensed as described in the file COPYING, which 12# you should have received as part of this distribution. The terms 13# are also available at https://curl.se/docs/copyright.html. 14# 15# You may opt to use, copy, modify, merge, publish, distribute and/or sell 16# copies of the Software, and permit persons to whom the Software is 17# furnished to do so, under the terms of the COPYING file. 18# 19# This software is distributed on an "AS IS" basis, WITHOUT WARRANTY OF ANY 20# KIND, either express or implied. 21# 22# SPDX-License-Identifier: curl 23# 24"""Server for testing SMB""" 25 26from __future__ import (absolute_import, division, print_function, 27 unicode_literals) 28 29import argparse 30import logging 31import os 32import signal 33import sys 34import tempfile 35import threading 36 37# Import our curl test data helper 38from util import ClosingFileHandler, TestData 39 40if sys.version_info.major >= 3: 41 import configparser 42else: 43 import ConfigParser as configparser 44 45# impacket needs to be installed in the Python environment 46try: 47 import impacket 48except ImportError: 49 sys.stderr.write( 50 'Warning: Python package impacket is required for smb testing; ' 51 'use pip or your package manager to install it\n') 52 sys.exit(1) 53from impacket import smb as imp_smb 54from impacket import smbserver as imp_smbserver 55from impacket.nt_errors import (STATUS_ACCESS_DENIED, STATUS_NO_SUCH_FILE, 56 STATUS_SUCCESS) 57 58log = logging.getLogger(__name__) 59SERVER_MAGIC = "SERVER_MAGIC" 60TESTS_MAGIC = "TESTS_MAGIC" 61VERIFIED_REQ = "verifiedserver" 62VERIFIED_RSP = "WE ROOLZ: {pid}\n" 63 64 65class ShutdownHandler(threading.Thread): 66 """Cleanly shut down the SMB server 67 68 This can only be done from another thread while the server is in 69 serve_forever(), so a thread is spawned here that waits for a shutdown 70 signal before doing its thing. Use in a with statement around the 71 serve_forever() call. 72 """ 73 74 def __init__(self, server): 75 super(ShutdownHandler, self).__init__() 76 self.server = server 77 self.shutdown_event = threading.Event() 78 79 def __enter__(self): 80 self.start() 81 signal.signal(signal.SIGINT, self._sighandler) 82 signal.signal(signal.SIGTERM, self._sighandler) 83 84 def __exit__(self, *_): 85 # Call for shutdown just in case it wasn't done already 86 self.shutdown_event.set() 87 # Wait for thread, and therefore also the server, to finish 88 self.join() 89 # Uninstall our signal handlers 90 signal.signal(signal.SIGINT, signal.SIG_DFL) 91 signal.signal(signal.SIGTERM, signal.SIG_DFL) 92 # Delete any temporary files created by the server during its run 93 log.info("Deleting %d temporary file(s)", len(self.server.tmpfiles)) 94 for f in self.server.tmpfiles: 95 os.unlink(f) 96 97 def _sighandler(self, _signum, _frame): 98 # Wake up the cleanup task 99 self.shutdown_event.set() 100 101 def run(self): 102 # Wait for shutdown signal 103 self.shutdown_event.wait() 104 # Notify the server to shut down 105 self.server.shutdown() 106 107 108def smbserver(options): 109 """Start up a TCP SMB server that serves forever 110 111 """ 112 if options.pidfile: 113 pid = os.getpid() 114 # see tests/server/util.c function write_pidfile 115 if os.name == "nt": 116 pid += 65536 117 with open(options.pidfile, "w") as f: 118 f.write(str(pid)) 119 120 # Here we write a mini config for the server 121 smb_config = configparser.ConfigParser() 122 smb_config.add_section("global") 123 smb_config.set("global", "server_name", "SERVICE") 124 smb_config.set("global", "server_os", "UNIX") 125 smb_config.set("global", "server_domain", "WORKGROUP") 126 smb_config.set("global", "log_file", "None") 127 smb_config.set("global", "credentials_file", "") 128 129 # We need a share which allows us to test that the server is running 130 smb_config.add_section("SERVER") 131 smb_config.set("SERVER", "comment", "server function") 132 smb_config.set("SERVER", "read only", "yes") 133 smb_config.set("SERVER", "share type", "0") 134 smb_config.set("SERVER", "path", SERVER_MAGIC) 135 136 # Have a share for tests. These files will be autogenerated from the 137 # test input. 138 smb_config.add_section("TESTS") 139 smb_config.set("TESTS", "comment", "tests") 140 smb_config.set("TESTS", "read only", "yes") 141 smb_config.set("TESTS", "share type", "0") 142 smb_config.set("TESTS", "path", TESTS_MAGIC) 143 144 if not options.srcdir or not os.path.isdir(options.srcdir): 145 raise ScriptException("--srcdir is mandatory") 146 147 test_data_dir = os.path.join(options.srcdir, "data") 148 149 smb_server = TestSmbServer((options.host, options.port), 150 config_parser=smb_config, 151 test_data_directory=test_data_dir) 152 log.info("[SMB] setting up SMB server on port %s", options.port) 153 smb_server.processConfigFile() 154 155 # Start a thread that cleanly shuts down the server on a signal 156 with ShutdownHandler(smb_server): 157 # This will block until smb_server.shutdown() is called 158 smb_server.serve_forever() 159 160 return 0 161 162 163class TestSmbServer(imp_smbserver.SMBSERVER): 164 """ 165 Test server for SMB which subclasses the impacket SMBSERVER and provides 166 test functionality. 167 """ 168 169 def __init__(self, 170 address, 171 config_parser=None, 172 test_data_directory=None): 173 imp_smbserver.SMBSERVER.__init__(self, 174 address, 175 config_parser=config_parser) 176 self.tmpfiles = [] 177 178 # Set up a test data object so we can get test data later. 179 self.ctd = TestData(test_data_directory) 180 181 # Override smbComNtCreateAndX so we can pretend to have files which 182 # don't exist. 183 self.hookSmbCommand(imp_smb.SMB.SMB_COM_NT_CREATE_ANDX, 184 self.create_and_x) 185 186 def create_and_x(self, conn_id, smb_server, smb_command, recv_packet): 187 """ 188 Our version of smbComNtCreateAndX looks for special test files and 189 fools the rest of the framework into opening them as if they were 190 normal files. 191 """ 192 conn_data = smb_server.getConnectionData(conn_id) 193 194 # Wrap processing in a try block which allows us to throw SmbException 195 # to control the flow. 196 try: 197 ncax_parms = imp_smb.SMBNtCreateAndX_Parameters( 198 smb_command["Parameters"]) 199 200 path = self.get_share_path(conn_data, 201 ncax_parms["RootFid"], 202 recv_packet["Tid"]) 203 log.info("[SMB] Requested share path: %s", path) 204 205 disposition = ncax_parms["Disposition"] 206 log.debug("[SMB] Requested disposition: %s", disposition) 207 208 # Currently we only support reading files. 209 if disposition != imp_smb.FILE_OPEN: 210 raise SmbException(STATUS_ACCESS_DENIED, 211 "Only support reading files") 212 213 # Check to see if the path we were given is actually a 214 # magic path which needs generating on the fly. 215 if path not in [SERVER_MAGIC, TESTS_MAGIC]: 216 # Pass the command onto the original handler. 217 return imp_smbserver.SMBCommands.smbComNtCreateAndX(conn_id, 218 smb_server, 219 smb_command, 220 recv_packet) 221 222 flags2 = recv_packet["Flags2"] 223 ncax_data = imp_smb.SMBNtCreateAndX_Data(flags=flags2, 224 data=smb_command[ 225 "Data"]) 226 requested_file = imp_smbserver.decodeSMBString( 227 flags2, 228 ncax_data["FileName"]) 229 log.debug("[SMB] User requested file '%s'", requested_file) 230 231 if path == SERVER_MAGIC: 232 fid, full_path = self.get_server_path(requested_file) 233 else: 234 assert (path == TESTS_MAGIC) 235 fid, full_path = self.get_test_path(requested_file) 236 237 self.tmpfiles.append(full_path) 238 239 resp_parms = imp_smb.SMBNtCreateAndXResponse_Parameters() 240 resp_data = "" 241 242 # Simple way to generate a fid 243 if len(conn_data["OpenedFiles"]) == 0: 244 fakefid = 1 245 else: 246 fakefid = conn_data["OpenedFiles"].keys()[-1] + 1 247 resp_parms["Fid"] = fakefid 248 resp_parms["CreateAction"] = disposition 249 250 if os.path.isdir(path): 251 resp_parms[ 252 "FileAttributes"] = imp_smb.SMB_FILE_ATTRIBUTE_DIRECTORY 253 resp_parms["IsDirectory"] = 1 254 else: 255 resp_parms["IsDirectory"] = 0 256 resp_parms["FileAttributes"] = ncax_parms["FileAttributes"] 257 258 # Get this file's information 259 resp_info, error_code = imp_smbserver.queryPathInformation( 260 os.path.dirname(full_path), os.path.basename(full_path), 261 level=imp_smb.SMB_QUERY_FILE_ALL_INFO) 262 263 if error_code != STATUS_SUCCESS: 264 raise SmbException(error_code, "Failed to query path info") 265 266 resp_parms["CreateTime"] = resp_info["CreationTime"] 267 resp_parms["LastAccessTime"] = resp_info[ 268 "LastAccessTime"] 269 resp_parms["LastWriteTime"] = resp_info["LastWriteTime"] 270 resp_parms["LastChangeTime"] = resp_info[ 271 "LastChangeTime"] 272 resp_parms["FileAttributes"] = resp_info[ 273 "ExtFileAttributes"] 274 resp_parms["AllocationSize"] = resp_info[ 275 "AllocationSize"] 276 resp_parms["EndOfFile"] = resp_info["EndOfFile"] 277 278 # Let's store the fid for the connection 279 # smbServer.log("Create file %s, mode:0x%x" % (pathName, mode)) 280 conn_data["OpenedFiles"][fakefid] = {} 281 conn_data["OpenedFiles"][fakefid]["FileHandle"] = fid 282 conn_data["OpenedFiles"][fakefid]["FileName"] = path 283 conn_data["OpenedFiles"][fakefid]["DeleteOnClose"] = False 284 285 except SmbException as s: 286 log.debug("[SMB] SmbException hit: %s", s) 287 error_code = s.error_code 288 resp_parms = "" 289 resp_data = "" 290 291 resp_cmd = imp_smb.SMBCommand(imp_smb.SMB.SMB_COM_NT_CREATE_ANDX) 292 resp_cmd["Parameters"] = resp_parms 293 resp_cmd["Data"] = resp_data 294 smb_server.setConnectionData(conn_id, conn_data) 295 296 return [resp_cmd], None, error_code 297 298 def get_share_path(self, conn_data, root_fid, tid): 299 conn_shares = conn_data["ConnectedShares"] 300 301 if tid in conn_shares: 302 if root_fid > 0: 303 # If we have a rootFid, the path is relative to that fid 304 path = conn_data["OpenedFiles"][root_fid]["FileName"] 305 log.debug("RootFid present %s!" % path) 306 else: 307 if "path" in conn_shares[tid]: 308 path = conn_shares[tid]["path"] 309 else: 310 raise SmbException(STATUS_ACCESS_DENIED, 311 "Connection share had no path") 312 else: 313 raise SmbException(imp_smbserver.STATUS_SMB_BAD_TID, 314 "TID was invalid") 315 316 return path 317 318 def get_server_path(self, requested_filename): 319 log.debug("[SMB] Get server path '%s'", requested_filename) 320 321 if requested_filename not in [VERIFIED_REQ]: 322 raise SmbException(STATUS_NO_SUCH_FILE, "Couldn't find the file") 323 324 fid, filename = tempfile.mkstemp() 325 log.debug("[SMB] Created %s (%d) for storing '%s'", 326 filename, fid, requested_filename) 327 328 contents = "" 329 330 if requested_filename == VERIFIED_REQ: 331 log.debug("[SMB] Verifying server is alive") 332 pid = os.getpid() 333 # see tests/server/util.c function write_pidfile 334 if os.name == "nt": 335 pid += 65536 336 contents = VERIFIED_RSP.format(pid=pid).encode('utf-8') 337 338 self.write_to_fid(fid, contents) 339 return fid, filename 340 341 def write_to_fid(self, fid, contents): 342 # Write the contents to file descriptor 343 os.write(fid, contents) 344 os.fsync(fid) 345 346 # Rewind the file to the beginning so a read gets us the contents 347 os.lseek(fid, 0, os.SEEK_SET) 348 349 def get_test_path(self, requested_filename): 350 log.info("[SMB] Get reply data from 'test%s'", requested_filename) 351 352 fid, filename = tempfile.mkstemp() 353 log.debug("[SMB] Created %s (%d) for storing test '%s'", 354 filename, fid, requested_filename) 355 356 try: 357 contents = self.ctd.get_test_data(requested_filename).encode('utf-8') 358 self.write_to_fid(fid, contents) 359 return fid, filename 360 361 except Exception: 362 log.exception("Failed to make test file") 363 raise SmbException(STATUS_NO_SUCH_FILE, "Failed to make test file") 364 365 366class SmbException(Exception): 367 def __init__(self, error_code, error_message): 368 super(SmbException, self).__init__(error_message) 369 self.error_code = error_code 370 371 372class ScriptRC(object): 373 """Enum for script return codes""" 374 SUCCESS = 0 375 FAILURE = 1 376 EXCEPTION = 2 377 378 379class ScriptException(Exception): 380 pass 381 382 383def get_options(): 384 parser = argparse.ArgumentParser() 385 386 parser.add_argument("--port", action="store", default=9017, 387 type=int, help="port to listen on") 388 parser.add_argument("--host", action="store", default="127.0.0.1", 389 help="host to listen on") 390 parser.add_argument("--verbose", action="store", type=int, default=0, 391 help="verbose output") 392 parser.add_argument("--pidfile", action="store", 393 help="file name for the PID") 394 parser.add_argument("--logfile", action="store", 395 help="file name for the log") 396 parser.add_argument("--srcdir", action="store", help="test directory") 397 parser.add_argument("--id", action="store", help="server ID") 398 parser.add_argument("--ipv4", action="store_true", default=0, 399 help="IPv4 flag") 400 401 return parser.parse_args() 402 403 404def setup_logging(options): 405 """ 406 Set up logging from the command line options 407 """ 408 root_logger = logging.getLogger() 409 add_stdout = False 410 411 formatter = logging.Formatter("%(asctime)s %(levelname)-5.5s %(message)s") 412 413 # Write out to a logfile 414 if options.logfile: 415 handler = ClosingFileHandler(options.logfile) 416 handler.setFormatter(formatter) 417 handler.setLevel(logging.DEBUG) 418 root_logger.addHandler(handler) 419 else: 420 # The logfile wasn't specified. Add a stdout logger. 421 add_stdout = True 422 423 if options.verbose: 424 # Add a stdout logger as well in verbose mode 425 root_logger.setLevel(logging.DEBUG) 426 add_stdout = True 427 else: 428 root_logger.setLevel(logging.WARNING) 429 430 if add_stdout: 431 stdout_handler = logging.StreamHandler(sys.stdout) 432 stdout_handler.setFormatter(formatter) 433 stdout_handler.setLevel(logging.DEBUG) 434 root_logger.addHandler(stdout_handler) 435 436 437if __name__ == '__main__': 438 # Get the options from the user. 439 options = get_options() 440 441 # Setup logging using the user options 442 setup_logging(options) 443 444 # Run main script. 445 try: 446 rc = smbserver(options) 447 except Exception as e: 448 log.exception(e) 449 rc = ScriptRC.EXCEPTION 450 451 if options.pidfile and os.path.isfile(options.pidfile): 452 os.unlink(options.pidfile) 453 454 log.info("[SMB] Returning %d", rc) 455 sys.exit(rc) 456