#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
#  Project                     ___| | | |  _ \| |
#                             / __| | | | |_) | |
#                            | (__| |_| |  _ <| |___
#                             \___|\___/|_| \_\_____|
#
# Copyright (C) Daniel Stenberg, <daniel@haxx.se>, et al.
#
# This software is licensed as described in the file COPYING, which
# you should have received as part of this distribution. The terms
# are also available at https://curl.se/docs/copyright.html.
#
# You may opt to use, copy, modify, merge, publish, distribute and/or sell
# copies of the Software, and permit persons to whom the Software is
# furnished to do so, under the terms of the COPYING file.
#
# This software is distributed on an "AS IS" basis, WITHOUT WARRANTY OF ANY
# KIND, either express or implied.
#
# SPDX-License-Identifier: curl
#
"""Server for testing SMB."""

from __future__ import (absolute_import, division, print_function,
                        unicode_literals)

import argparse
import logging
import os
import signal
import sys
import tempfile
import threading

# Import our curl test data helper
from util import ClosingFileHandler, TestData

if sys.version_info.major >= 3:
    import configparser
else:
    import ConfigParser as configparser

# impacket needs to be installed in the Python environment
try:
    import impacket  # noqa: F401
except ImportError:
    sys.stderr.write(
        'Warning: Python package impacket is required for smb testing; '
        'use pip or your package manager to install it\n')
    sys.exit(1)
from impacket import smb as imp_smb
from impacket import smbserver as imp_smbserver
from impacket.nt_errors import (STATUS_ACCESS_DENIED, STATUS_NO_SUCH_FILE,
                                STATUS_SUCCESS)

log = logging.getLogger(__name__)
SERVER_MAGIC = "SERVER_MAGIC"
TESTS_MAGIC = "TESTS_MAGIC"
VERIFIED_REQ = "verifiedserver"
VERIFIED_RSP = "WE ROOLZ: {pid}\n"


def _reported_pid():
    """Return the PID this server reports in its pidfile and verify reply."""
    pid = os.getpid()
    # see tests/server/util.c function write_pidfile
    if os.name == "nt":
        pid += 65536
    return pid


def _discard_temp_file(fid, filename):
    """Close and delete a mkstemp() file, logging (not raising) failures."""
    try:
        os.close(fid)
    except OSError as exc:
        log.debug("[SMB] Could not close descriptor %d: %s", fid, exc)
    try:
        os.unlink(filename)
    except OSError as exc:
        log.debug("[SMB] Could not delete %s: %s", filename, exc)


class ShutdownHandler(threading.Thread):
    """
    Cleanly shut down the SMB server.

    This can only be done from another thread while the server is in
    serve_forever(), so a thread is spawned here that waits for a shutdown
    signal before doing its thing. Use in a with statement around the
    serve_forever() call.
    """

    def __init__(self, server):
        """Remember the server to stop; it must expose shutdown()/tmpfiles."""
        super(ShutdownHandler, self).__init__()
        self.server = server
        self.shutdown_event = threading.Event()

    def __enter__(self):
        """Start the waiter thread and route SIGINT/SIGTERM to it."""
        self.start()
        signal.signal(signal.SIGINT, self._sighandler)
        signal.signal(signal.SIGTERM, self._sighandler)
        return self

    def __exit__(self, *_):
        """Stop the waiter thread, restore signals and delete temp files."""
        # Call for shutdown just in case it wasn't done already
        self.shutdown_event.set()
        # Wait for thread, and therefore also the server, to finish
        self.join()
        # Uninstall our signal handlers
        signal.signal(signal.SIGINT, signal.SIG_DFL)
        signal.signal(signal.SIGTERM, signal.SIG_DFL)
        # Delete any temporary files created by the server during its run
        log.info("Deleting %d temporary file(s)", len(self.server.tmpfiles))
        for f in self.server.tmpfiles:
            # One undeletable file must not prevent removal of the others.
            try:
                os.unlink(f)
            except OSError as exc:
                log.warning("Could not delete temporary file %s: %s", f, exc)

    def _sighandler(self, _signum, _frame):
        # Wake up the cleanup task
        self.shutdown_event.set()

    def run(self):
        """Thread body: block until shutdown is requested, then stop server."""
        # Wait for shutdown signal
        self.shutdown_event.wait()
        # Notify the server to shut down
        self.server.shutdown()


def smbserver(options):
    """
    Start up a TCP SMB server that serves forever.

    Reads ``pidfile``, ``srcdir``, ``host`` and ``port`` from *options*.
    Writes the PID to ``options.pidfile`` when one is given, then blocks in
    serve_forever() until a ShutdownHandler stops the server.

    Raises ScriptError when ``options.srcdir`` is missing or not a directory.
    Returns 0 once the server has shut down.
    """
    if options.pidfile:
        with open(options.pidfile, "w") as f:
            f.write(str(_reported_pid()))

    # Here we write a mini config for the server
    smb_config = configparser.ConfigParser()
    smb_config.add_section("global")
    smb_config.set("global", "server_name", "SERVICE")
    smb_config.set("global", "server_os", "UNIX")
    smb_config.set("global", "server_domain", "WORKGROUP")
    smb_config.set("global", "log_file", "None")
    smb_config.set("global", "credentials_file", "")

    # We need a share which allows us to test that the server is running
    smb_config.add_section("SERVER")
    smb_config.set("SERVER", "comment", "server function")
    smb_config.set("SERVER", "read only", "yes")
    smb_config.set("SERVER", "share type", "0")
    smb_config.set("SERVER", "path", SERVER_MAGIC)

    # Have a share for tests.  These files will be autogenerated from the
    # test input.
    smb_config.add_section("TESTS")
    smb_config.set("TESTS", "comment", "tests")
    smb_config.set("TESTS", "read only", "yes")
    smb_config.set("TESTS", "share type", "0")
    smb_config.set("TESTS", "path", TESTS_MAGIC)

    if not options.srcdir or not os.path.isdir(options.srcdir):
        raise ScriptError("--srcdir is mandatory")

    test_data_dir = os.path.join(options.srcdir, "data")

    smb_server = TestSmbServer((options.host, options.port),
                               config_parser=smb_config,
                               test_data_directory=test_data_dir)
    log.info("[SMB] setting up SMB server on port %s", options.port)
    smb_server.processConfigFile()

    # Start a thread that cleanly shuts down the server on a signal
    with ShutdownHandler(smb_server):
        # This will block until smb_server.shutdown() is called
        smb_server.serve_forever()

    return 0


class TestSmbServer(imp_smbserver.SMBSERVER):
    """
    Test server for SMB which subclasses the impacket SMBSERVER and provides
    test functionality.
    """

    def __init__(self,
                 address,
                 config_parser=None,
                 test_data_directory=None):
        """
        Create the server and hook the NT_CREATE_ANDX command.

        address is passed through to impacket's SMBSERVER together with
        config_parser; test_data_directory is handed to TestData.
        """
        imp_smbserver.SMBSERVER.__init__(self,
                                         address,
                                         config_parser=config_parser)
        # Paths of every temporary file generated; ShutdownHandler deletes
        # them when the server stops.
        self.tmpfiles = []

        # Set up a test data object so we can get test data later.
        self.ctd = TestData(test_data_directory)

        # Override smbComNtCreateAndX so we can pretend to have files which
        # don't exist.
        self.hookSmbCommand(imp_smb.SMB.SMB_COM_NT_CREATE_ANDX,
                            self.create_and_x)

    def create_and_x(self, conn_id, smb_server, smb_command, recv_packet):
        """
        Our version of smbComNtCreateAndX looks for special test files and
        fools the rest of the framework into opening them as if they were
        normal files.

        Requests on shares whose path is not SERVER_MAGIC or TESTS_MAGIC are
        forwarded to impacket's own smbComNtCreateAndX and its result is
        returned unchanged. Otherwise returns the tuple
        ``([response_command], None, error_code)``.
        """
        conn_data = smb_server.getConnectionData(conn_id)

        # Wrap processing in a try block which allows us to throw SmbError
        # to control the flow.
        try:
            ncax_parms = imp_smb.SMBNtCreateAndX_Parameters(
                smb_command["Parameters"])

            path = self.get_share_path(conn_data,
                                       ncax_parms["RootFid"],
                                       recv_packet["Tid"])
            log.info("[SMB] Requested share path: %s", path)

            disposition = ncax_parms["Disposition"]
            log.debug("[SMB] Requested disposition: %s", disposition)

            # Currently we only support reading files.
            if disposition != imp_smb.FILE_OPEN:
                raise SmbError(STATUS_ACCESS_DENIED,
                               "Only support reading files")

            # Check to see if the path we were given is actually a
            # magic path which needs generating on the fly.
            if path not in [SERVER_MAGIC, TESTS_MAGIC]:
                # Pass the command onto the original handler.
                return imp_smbserver.SMBCommands.smbComNtCreateAndX(
                    conn_id, smb_server, smb_command, recv_packet)

            flags2 = recv_packet["Flags2"]
            ncax_data = imp_smb.SMBNtCreateAndX_Data(
                flags=flags2, data=smb_command["Data"])
            requested_file = imp_smbserver.decodeSMBString(
                flags2,
                ncax_data["FileName"])
            log.debug("[SMB] User requested file '%s'", requested_file)

            if path == SERVER_MAGIC:
                fid, full_path = self.get_server_path(requested_file)
            else:
                # Only TESTS_MAGIC is left after the membership check above.
                fid, full_path = self.get_test_path(requested_file)

            self.tmpfiles.append(full_path)

            resp_parms = imp_smb.SMBNtCreateAndXResponse_Parameters()
            resp_data = ""

            # Simple way to generate a fid: one above the highest in use.
            # (dict.keys() is a non-indexable view on Python 3, so the
            # keys cannot be subscripted directly.)
            opened_files = conn_data["OpenedFiles"]
            if not opened_files:
                fakefid = 1
            else:
                fakefid = max(opened_files) + 1
            resp_parms["Fid"] = fakefid
            resp_parms["CreateAction"] = disposition

            if os.path.isdir(path):
                resp_parms[
                    "FileAttributes"] = imp_smb.SMB_FILE_ATTRIBUTE_DIRECTORY
                resp_parms["IsDirectory"] = 1
            else:
                resp_parms["IsDirectory"] = 0
                resp_parms["FileAttributes"] = ncax_parms["FileAttributes"]

            # Get this file's information
            resp_info, error_code = imp_smbserver.queryPathInformation(
                os.path.dirname(full_path), os.path.basename(full_path),
                level=imp_smb.SMB_QUERY_FILE_ALL_INFO)

            if error_code != STATUS_SUCCESS:
                # The descriptor is not registered with the connection yet,
                # so nobody else would ever close it. The file itself is
                # already listed in tmpfiles and is removed at shutdown.
                os.close(fid)
                raise SmbError(error_code, "Failed to query path info")

            resp_parms["CreateTime"] = resp_info["CreationTime"]
            resp_parms["LastAccessTime"] = resp_info["LastAccessTime"]
            resp_parms["LastWriteTime"] = resp_info["LastWriteTime"]
            resp_parms["LastChangeTime"] = resp_info["LastChangeTime"]
            resp_parms["FileAttributes"] = resp_info["ExtFileAttributes"]
            resp_parms["AllocationSize"] = resp_info["AllocationSize"]
            resp_parms["EndOfFile"] = resp_info["EndOfFile"]

            # Let's store the fid for the connection
            opened_files[fakefid] = {
                "FileHandle": fid,
                "FileName": path,
                "DeleteOnClose": False,
            }

        except SmbError as s:
            log.debug("[SMB] SmbError hit: %s", s)
            error_code = s.error_code
            resp_parms = ""
            resp_data = ""

        resp_cmd = imp_smb.SMBCommand(imp_smb.SMB.SMB_COM_NT_CREATE_ANDX)
        resp_cmd["Parameters"] = resp_parms
        resp_cmd["Data"] = resp_data
        smb_server.setConnectionData(conn_id, conn_data)

        return [resp_cmd], None, error_code

    def get_share_path(self, conn_data, root_fid, tid):
        """
        Return the path associated with a request.

        When root_fid is positive the "FileName" recorded for that opened
        file is returned; otherwise the "path" of the connected share tid.

        Raises SmbError with STATUS_SMB_BAD_TID when tid is not a connected
        share, and with STATUS_ACCESS_DENIED when the share has no path.
        """
        conn_shares = conn_data["ConnectedShares"]

        if tid not in conn_shares:
            raise SmbError(imp_smbserver.STATUS_SMB_BAD_TID,
                           "TID was invalid")

        if root_fid > 0:
            # If we have a rootFid, the path is relative to that fid
            path = conn_data["OpenedFiles"][root_fid]["FileName"]
            log.debug("RootFid present %s!", path)
            return path

        if "path" not in conn_shares[tid]:
            raise SmbError(STATUS_ACCESS_DENIED,
                           "Connection share had no path")

        return conn_shares[tid]["path"]

    def get_server_path(self, requested_filename):
        """
        Create the temporary file backing a request on the SERVER share.

        Only VERIFIED_REQ is known; its file holds VERIFIED_RSP filled in
        with the reported PID. Returns ``(fid, filename)`` from mkstemp().

        Raises SmbError with STATUS_NO_SUCH_FILE for any other name.
        """
        log.debug("[SMB] Get server path '%s'", requested_filename)

        if requested_filename not in [VERIFIED_REQ]:
            raise SmbError(STATUS_NO_SUCH_FILE, "Couldn't find the file")

        fid, filename = tempfile.mkstemp()
        log.debug("[SMB] Created %s (%d) for storing '%s'",
                  filename, fid, requested_filename)

        # os.write() needs bytes, so start from an empty bytes object.
        contents = b""

        if requested_filename == VERIFIED_REQ:
            log.debug("[SMB] Verifying server is alive")
            contents = VERIFIED_RSP.format(pid=_reported_pid()).encode('utf-8')

        try:
            self.write_to_fid(fid, contents)
        except OSError:
            # Nobody else knows about this file yet; clean it up here.
            _discard_temp_file(fid, filename)
            raise
        return fid, filename

    def write_to_fid(self, fid, contents):
        """Write all of contents (bytes) to fid, sync it and rewind to 0."""
        # os.write() may write fewer bytes than asked; loop until done.
        offset = 0
        while offset < len(contents):
            offset += os.write(fid, contents[offset:])
        os.fsync(fid)

        # Rewind the file to the beginning so a read gets us the contents
        os.lseek(fid, 0, os.SEEK_SET)

    def get_test_path(self, requested_filename):
        """
        Create the temporary file backing a request on the TESTS share.

        The contents come from TestData.get_test_data(requested_filename),
        encoded as UTF-8. Returns ``(fid, filename)`` from mkstemp().

        Raises SmbError with STATUS_NO_SUCH_FILE when the data cannot be
        fetched or written; the temporary file is removed in that case.
        """
        log.info("[SMB] Get reply data from 'test%s'", requested_filename)

        fid, filename = tempfile.mkstemp()
        log.debug("[SMB] Created %s (%d) for storing test '%s'",
                  filename, fid, requested_filename)

        try:
            contents = self.ctd.get_test_data(
                requested_filename).encode('utf-8')
            self.write_to_fid(fid, contents)
            return fid, filename

        except Exception:
            log.exception("Failed to make test file")
            # The caller never learns about this file, so it would not be
            # deleted at shutdown and its descriptor would stay open.
            _discard_temp_file(fid, filename)
            raise SmbError(STATUS_NO_SUCH_FILE, "Failed to make test file")


class SmbError(Exception):
    """Exception carrying the NT status code to return to the SMB client."""

    def __init__(self, error_code, error_message):
        super(SmbError, self).__init__(error_message)
        self.error_code = error_code
class ScriptRC(object):
    """Exit statuses reported by this script."""

    SUCCESS = 0
    FAILURE = 1
    EXCEPTION = 2


class ScriptError(Exception):
    """Script-level failure, such as a missing mandatory option."""


def get_options():
    """Parse the command line and return the argparse namespace."""
    # Each entry is (flag, keyword arguments for add_argument); the order
    # here is the order the options are registered in.
    argument_specs = (
        ("--port", dict(action="store", default=9017,
                        type=int, help="port to listen on")),
        ("--host", dict(action="store", default="127.0.0.1",
                        help="host to listen on")),
        ("--verbose", dict(action="store", type=int, default=0,
                           help="verbose output")),
        ("--pidfile", dict(action="store",
                           help="file name for the PID")),
        ("--logfile", dict(action="store",
                           help="file name for the log")),
        ("--srcdir", dict(action="store", help="test directory")),
        ("--id", dict(action="store", help="server ID")),
        ("--ipv4", dict(action="store_true", default=0,
                        help="IPv4 flag")),
    )

    cli = argparse.ArgumentParser()
    for flag, settings in argument_specs:
        cli.add_argument(flag, **settings)

    return cli.parse_args()


def setup_logging(options):
    """Configure the root logger from options.logfile and options.verbose."""
    root = logging.getLogger()
    line_format = logging.Formatter(
        "%(asctime)s %(levelname)-5.5s %(message)s")

    def attach(handler):
        # Every handler shares the same format and lets everything through;
        # filtering happens at the root logger's level.
        handler.setFormatter(line_format)
        handler.setLevel(logging.DEBUG)
        root.addHandler(handler)

    # Write out to a logfile when one was requested
    if options.logfile:
        attach(ClosingFileHandler(options.logfile))

    root.setLevel(logging.DEBUG if options.verbose else logging.WARNING)

    # stdout is used in verbose mode, and as the fallback when no logfile
    # was specified.
    if options.verbose or not options.logfile:
        attach(logging.StreamHandler(sys.stdout))


if __name__ == '__main__':
    # Command line first, then logging configured from it.
    options = get_options()
    setup_logging(options)

    # Run the server; any exception becomes the EXCEPTION return code.
    try:
        rc = smbserver(options)
    except Exception:
        log.exception('Error in SMB server')
        rc = ScriptRC.EXCEPTION

    pidfile = options.pidfile
    if pidfile and os.path.isfile(pidfile):
        os.unlink(pidfile)

    log.info("[SMB] Returning %d", rc)
    sys.exit(rc)