xref: /curl/tests/smbserver.py (revision 895008de)
1#!/usr/bin/env python3
2# -*- coding: utf-8 -*-
3#
4#  Project                     ___| | | |  _ \| |
5#                             / __| | | | |_) | |
6#                            | (__| |_| |  _ <| |___
7#                             \___|\___/|_| \_\_____|
8#
9# Copyright (C) Daniel Stenberg, <daniel@haxx.se>, et al.
10#
11# This software is licensed as described in the file COPYING, which
12# you should have received as part of this distribution. The terms
13# are also available at https://curl.se/docs/copyright.html.
14#
15# You may opt to use, copy, modify, merge, publish, distribute and/or sell
16# copies of the Software, and permit persons to whom the Software is
17# furnished to do so, under the terms of the COPYING file.
18#
19# This software is distributed on an "AS IS" basis, WITHOUT WARRANTY OF ANY
20# KIND, either express or implied.
21#
22# SPDX-License-Identifier: curl
23#
24"""Server for testing SMB"""
25
26from __future__ import (absolute_import, division, print_function,
27                        unicode_literals)
28
29import argparse
30import logging
31import os
32import signal
33import sys
34import tempfile
35import threading
36
37# Import our curl test data helper
38from util import ClosingFileHandler, TestData
39
40if sys.version_info.major >= 3:
41    import configparser
42else:
43    import ConfigParser as configparser
44
45# impacket needs to be installed in the Python environment
46try:
47    import impacket
48except ImportError:
49    sys.stderr.write(
50        'Warning: Python package impacket is required for smb testing; '
51        'use pip or your package manager to install it\n')
52    sys.exit(1)
53from impacket import smb as imp_smb
54from impacket import smbserver as imp_smbserver
55from impacket.nt_errors import (STATUS_ACCESS_DENIED, STATUS_NO_SUCH_FILE,
56                                STATUS_SUCCESS)
57
# Module-level logger shared by everything in this script
log = logging.getLogger(__name__)
# Placeholder "path" values given to the SERVER and TESTS shares in the
# generated config; create_and_x() compares a share's path against these to
# decide whether the requested file has to be generated on the fly.
SERVER_MAGIC = "SERVER_MAGIC"
TESTS_MAGIC = "TESTS_MAGIC"
# The only file name served from the SERVER share, and the template of the
# content returned for it ({pid} is substituted in get_server_path()).
VERIFIED_REQ = "verifiedserver"
VERIFIED_RSP = "WE ROOLZ: {pid}\n"
63
64
class ShutdownHandler(threading.Thread):
    """Cleanly shut down the SMB server

    This can only be done from another thread while the server is in
    serve_forever(), so a thread is spawned here that waits for a shutdown
    signal before doing its thing. Use in a with statement around the
    serve_forever() call.

    On leaving the with block the SIGINT/SIGTERM handlers that were active
    before it was entered are put back, and every path listed in
    ``server.tmpfiles`` is deleted on a best-effort basis.
    """

    def __init__(self, server):
        super(ShutdownHandler, self).__init__()
        self.server = server
        self.shutdown_event = threading.Event()
        # Signal number -> handler that was installed before __enter__
        self._old_handlers = {}

    def __enter__(self):
        self.start()
        for signum in (signal.SIGINT, signal.SIGTERM):
            # signal.signal() hands back the handler it replaces
            self._old_handlers[signum] = signal.signal(signum,
                                                       self._sighandler)
        return self

    def __exit__(self, *_):
        # Call for shutdown just in case it wasn't done already
        self.shutdown_event.set()
        # Wait for thread, and therefore also the server, to finish
        self.join()
        # Put back whatever handlers were active before we installed ours
        for signum, handler in self._old_handlers.items():
            if handler is None:
                # The previous handler was not installed from Python and
                # cannot be reinstated; fall back to the default action.
                handler = signal.SIG_DFL
            signal.signal(signum, handler)
        self._old_handlers.clear()
        # Delete any temporary files created by the server during its run
        log.info("Deleting %d temporary file(s)", len(self.server.tmpfiles))
        for f in self.server.tmpfiles:
            try:
                os.unlink(f)
            except OSError as exc:
                # One file that cannot be removed must not keep the
                # remaining ones from being cleaned up.
                log.warning("Could not delete temporary file %s: %s", f, exc)

    def _sighandler(self, _signum, _frame):
        # Wake up the cleanup task
        self.shutdown_event.set()

    def run(self):
        # Wait for shutdown signal
        self.shutdown_event.wait()
        # Notify the server to shut down
        self.server.shutdown()
106
107
def smbserver(options):
    """Configure and run the test SMB server until it is told to stop

    Writes the PID file when one was requested, builds the server
    configuration (a "global" section plus the read-only SERVER and TESTS
    shares) and then serves requests until ShutdownHandler stops the server.

    Returns 0 once the server has shut down. Raises ScriptException when
    options.srcdir is unset or is not a directory.
    """
    if options.pidfile:
        # see tests/server/util.c function write_pidfile
        pid_offset = 65536 if os.name == "nt" else 0
        with open(options.pidfile, "w") as f:
            f.write(str(os.getpid() + pid_offset))

    # The mini config for the server, in the order it is written:
    #  - "global": general server settings
    #  - "SERVER": a share which allows us to test that the server is running
    #  - "TESTS":  a share whose files are autogenerated from the test input
    sections = (
        ("global", (
            ("server_name", "SERVICE"),
            ("server_os", "UNIX"),
            ("server_domain", "WORKGROUP"),
            ("log_file", "None"),
            ("credentials_file", ""),
        )),
        ("SERVER", (
            ("comment", "server function"),
            ("read only", "yes"),
            ("share type", "0"),
            ("path", SERVER_MAGIC),
        )),
        ("TESTS", (
            ("comment", "tests"),
            ("read only", "yes"),
            ("share type", "0"),
            ("path", TESTS_MAGIC),
        )),
    )
    smb_config = configparser.ConfigParser()
    for section, settings in sections:
        smb_config.add_section(section)
        for key, value in settings:
            smb_config.set(section, key, value)

    if not (options.srcdir and os.path.isdir(options.srcdir)):
        raise ScriptException("--srcdir is mandatory")

    smb_server = TestSmbServer(
        (options.host, options.port),
        config_parser=smb_config,
        test_data_directory=os.path.join(options.srcdir, "data"))
    log.info("[SMB] setting up SMB server on port %s", options.port)
    smb_server.processConfigFile()

    # The handler's thread shuts the server down cleanly on a signal;
    # serve_forever() blocks until smb_server.shutdown() is called.
    with ShutdownHandler(smb_server):
        smb_server.serve_forever()

    return 0
161
162
class TestSmbServer(imp_smbserver.SMBSERVER):
    """
    Test server for SMB which subclasses the impacket SMBSERVER and provides
    test functionality.

    Opens on shares whose path is SERVER_MAGIC or TESTS_MAGIC are answered
    with temporary files generated on the fly. The paths of those files are
    collected in ``self.tmpfiles`` so that they can be deleted once the
    server has shut down.
    """

    def __init__(self,
                 address,
                 config_parser=None,
                 test_data_directory=None):
        imp_smbserver.SMBSERVER.__init__(self,
                                         address,
                                         config_parser=config_parser)
        # Paths of all temporary files handed out to clients
        self.tmpfiles = []

        # Set up a test data object so we can get test data later.
        self.ctd = TestData(test_data_directory)

        # Override smbComNtCreateAndX so we can pretend to have files which
        # don't exist.
        self.hookSmbCommand(imp_smb.SMB.SMB_COM_NT_CREATE_ANDX,
                            self.create_and_x)

    def create_and_x(self, conn_id, smb_server, smb_command, recv_packet):
        """
        Our version of smbComNtCreateAndX looks for special test files and
        fools the rest of the framework into opening them as if they were
        normal files.

        Requests on a share whose path is neither SERVER_MAGIC nor
        TESTS_MAGIC are passed on to impacket's own handler. Otherwise a
        temporary file is generated and registered in the connection's
        "OpenedFiles" table under a new fid.

        Returns ``[resp_cmd], None, error_code``; when an SmbException was
        raised while processing, error_code is the one it carries and the
        response parameters and data are empty.
        """
        conn_data = smb_server.getConnectionData(conn_id)

        # Wrap processing in a try block which allows us to throw SmbException
        # to control the flow.
        try:
            ncax_parms = imp_smb.SMBNtCreateAndX_Parameters(
                smb_command["Parameters"])

            path = self.get_share_path(conn_data,
                                       ncax_parms["RootFid"],
                                       recv_packet["Tid"])
            log.info("[SMB] Requested share path: %s", path)

            disposition = ncax_parms["Disposition"]
            log.debug("[SMB] Requested disposition: %s", disposition)

            # Currently we only support reading files.
            if disposition != imp_smb.FILE_OPEN:
                raise SmbException(STATUS_ACCESS_DENIED,
                                   "Only support reading files")

            # Check to see if the path we were given is actually a
            # magic path which needs generating on the fly.
            if path not in [SERVER_MAGIC, TESTS_MAGIC]:
                # Pass the command onto the original handler.
                return imp_smbserver.SMBCommands.smbComNtCreateAndX(conn_id,
                                                                    smb_server,
                                                                    smb_command,
                                                                    recv_packet)

            flags2 = recv_packet["Flags2"]
            ncax_data = imp_smb.SMBNtCreateAndX_Data(flags=flags2,
                                                     data=smb_command[
                                                         "Data"])
            requested_file = imp_smbserver.decodeSMBString(
                flags2,
                ncax_data["FileName"])
            log.debug("[SMB] User requested file '%s'", requested_file)

            if path == SERVER_MAGIC:
                fid, full_path = self.get_server_path(requested_file)
            else:
                assert (path == TESTS_MAGIC)
                fid, full_path = self.get_test_path(requested_file)

            self.tmpfiles.append(full_path)

            resp_parms = imp_smb.SMBNtCreateAndXResponse_Parameters()
            resp_data = ""

            # Simple way to generate a fid: one past the highest fid in use.
            # (dict.keys() cannot be indexed on Python 3, and max() also
            # guarantees the new fid does not collide with an open one.)
            opened_files = conn_data["OpenedFiles"]
            if opened_files:
                fakefid = max(opened_files) + 1
            else:
                fakefid = 1
            resp_parms["Fid"] = fakefid
            resp_parms["CreateAction"] = disposition

            # Look at the generated file itself; "path" only holds the magic
            # share name here and says nothing about what is being opened.
            if os.path.isdir(full_path):
                resp_parms[
                    "FileAttributes"] = imp_smb.SMB_FILE_ATTRIBUTE_DIRECTORY
                resp_parms["IsDirectory"] = 1
            else:
                resp_parms["IsDirectory"] = 0
                resp_parms["FileAttributes"] = ncax_parms["FileAttributes"]

            # Get this file's information
            resp_info, error_code = imp_smbserver.queryPathInformation(
                os.path.dirname(full_path), os.path.basename(full_path),
                level=imp_smb.SMB_QUERY_FILE_ALL_INFO)

            if error_code != STATUS_SUCCESS:
                # The descriptor is not registered with the connection yet,
                # so nobody else would ever close it. The file itself stays
                # in self.tmpfiles and is removed at shutdown.
                os.close(fid)
                raise SmbException(error_code, "Failed to query path info")

            resp_parms["CreateTime"] = resp_info["CreationTime"]
            resp_parms["LastAccessTime"] = resp_info[
                "LastAccessTime"]
            resp_parms["LastWriteTime"] = resp_info["LastWriteTime"]
            resp_parms["LastChangeTime"] = resp_info[
                "LastChangeTime"]
            resp_parms["FileAttributes"] = resp_info[
                "ExtFileAttributes"]
            resp_parms["AllocationSize"] = resp_info[
                "AllocationSize"]
            resp_parms["EndOfFile"] = resp_info["EndOfFile"]

            # Let's store the fid for the connection
            conn_data["OpenedFiles"][fakefid] = {
                "FileHandle": fid,
                "FileName": path,
                "DeleteOnClose": False,
            }

        except SmbException as s:
            log.debug("[SMB] SmbException hit: %s", s)
            error_code = s.error_code
            resp_parms = ""
            resp_data = ""

        resp_cmd = imp_smb.SMBCommand(imp_smb.SMB.SMB_COM_NT_CREATE_ANDX)
        resp_cmd["Parameters"] = resp_parms
        resp_cmd["Data"] = resp_data
        smb_server.setConnectionData(conn_id, conn_data)

        return [resp_cmd], None, error_code

    def get_share_path(self, conn_data, root_fid, tid):
        """
        Return the path a create request refers to.

        With a non-zero root_fid this is the "FileName" recorded for that
        fid in the connection's opened files, otherwise the "path" of the
        connected share identified by tid.

        Raises SmbException when tid is not a connected share or when the
        share has no path.
        """
        conn_shares = conn_data["ConnectedShares"]

        if tid not in conn_shares:
            raise SmbException(imp_smbserver.STATUS_SMB_BAD_TID,
                               "TID was invalid")

        if root_fid > 0:
            # If we have a rootFid, the path is relative to that fid
            path = conn_data["OpenedFiles"][root_fid]["FileName"]
            log.debug("RootFid present %s!", path)
            return path

        if "path" not in conn_shares[tid]:
            raise SmbException(STATUS_ACCESS_DENIED,
                               "Connection share had no path")
        return conn_shares[tid]["path"]

    def get_server_path(self, requested_filename):
        """
        Create the temporary file served for a request on the SERVER share.

        Only VERIFIED_REQ is known; any other name raises SmbException with
        STATUS_NO_SUCH_FILE. Returns ``(fid, filename)`` of the temporary
        file, with the descriptor positioned at the start of the content.
        """
        log.debug("[SMB] Get server path '%s'", requested_filename)

        if requested_filename not in [VERIFIED_REQ]:
            raise SmbException(STATUS_NO_SUCH_FILE, "Couldn't find the file")

        fid, filename = tempfile.mkstemp()
        log.debug("[SMB] Created %s (%d) for storing '%s'",
                  filename, fid, requested_filename)

        # os.write() wants bytes, so the empty default must be bytes too
        contents = b""

        if requested_filename == VERIFIED_REQ:
            log.debug("[SMB] Verifying server is alive")
            pid = os.getpid()
            # see tests/server/util.c function write_pidfile
            if os.name == "nt":
                pid += 65536
            contents = VERIFIED_RSP.format(pid=pid).encode('utf-8')

        self.write_to_fid(fid, contents)
        return fid, filename

    def write_to_fid(self, fid, contents):
        """Write contents (bytes) to descriptor fid and rewind it"""
        # Write the contents to file descriptor
        os.write(fid, contents)
        os.fsync(fid)

        # Rewind the file to the beginning so a read gets us the contents
        os.lseek(fid, 0, os.SEEK_SET)

    def get_test_path(self, requested_filename):
        """
        Create the temporary file served for a request on the TESTS share.

        The content is the UTF-8 encoded test data that the TestData helper
        returns for requested_filename. Returns ``(fid, filename)`` of the
        temporary file. Any failure is logged and reported as SmbException
        with STATUS_NO_SUCH_FILE, after the temporary file has been
        discarded again.
        """
        log.info("[SMB] Get reply data from 'test%s'", requested_filename)

        fid, filename = tempfile.mkstemp()
        log.debug("[SMB] Created %s (%d) for storing test '%s'",
                  filename, fid, requested_filename)

        try:
            contents = self.ctd.get_test_data(
                requested_filename).encode('utf-8')
            self.write_to_fid(fid, contents)
        except Exception:
            log.exception("Failed to make test file")
            # The caller never learns about this file, so it has to be
            # cleaned up here or it would be leaked.
            self._discard_tmpfile(fid, filename)
            raise SmbException(STATUS_NO_SUCH_FILE, "Failed to make test file")

        return fid, filename

    @staticmethod
    def _discard_tmpfile(fid, filename):
        """Best-effort close and removal of an unused temporary file"""
        for cleanup, target in ((os.close, fid), (os.unlink, filename)):
            try:
                cleanup(target)
            except OSError as exc:
                log.debug("[SMB] Cleanup of %s failed: %s", target, exc)
364
365
class SmbException(Exception):
    """Error that carries the status code to report to the SMB client

    error_message becomes the ordinary exception message, error_code is
    kept on the instance as ``error_code``.
    """

    def __init__(self, error_code, error_message):
        Exception.__init__(self, error_message)
        self.error_code = error_code
370
371
class ScriptRC(object):
    """Return codes the script can exit with"""
    SUCCESS, FAILURE, EXCEPTION = range(3)
377
378
class ScriptException(Exception):
    """Raised when the script cannot run with the options it was given"""
381
382
def get_options():
    """Parse the command line (sys.argv) and return the options namespace"""
    # Each entry is the option name and the keyword arguments it is
    # registered with, in the order the options appear in the help output.
    arguments = (
        ("--port", dict(action="store", default=9017, type=int,
                        help="port to listen on")),
        ("--host", dict(action="store", default="127.0.0.1",
                        help="host to listen on")),
        ("--verbose", dict(action="store", type=int, default=0,
                           help="verbose output")),
        ("--pidfile", dict(action="store", help="file name for the PID")),
        ("--logfile", dict(action="store", help="file name for the log")),
        ("--srcdir", dict(action="store", help="test directory")),
        ("--id", dict(action="store", help="server ID")),
        ("--ipv4", dict(action="store_true", default=0, help="IPv4 flag")),
    )

    parser = argparse.ArgumentParser()
    for flag, settings in arguments:
        parser.add_argument(flag, **settings)

    return parser.parse_args()
402
403
def setup_logging(options):
    """
    Configure the root logger according to the command line options

    A file handler is attached when options.logfile is set. A stdout handler
    is attached when no log file was given or when options.verbose is set.
    The root logger level is DEBUG in verbose mode and WARNING otherwise;
    the handlers themselves let everything from DEBUG up through.
    """
    formatter = logging.Formatter("%(asctime)s %(levelname)-5.5s %(message)s")
    root_logger = logging.getLogger()

    handlers = []

    # Write out to a logfile when one was specified
    if options.logfile:
        handlers.append(ClosingFileHandler(options.logfile))

    root_logger.setLevel(
        logging.DEBUG if options.verbose else logging.WARNING)

    # stdout is used when there is no logfile, and additionally to the
    # logfile in verbose mode
    if options.verbose or not options.logfile:
        handlers.append(logging.StreamHandler(sys.stdout))

    for handler in handlers:
        handler.setFormatter(formatter)
        handler.setLevel(logging.DEBUG)
        root_logger.addHandler(handler)
435
436
if __name__ == '__main__':
    # Read the command line, then configure logging from it
    options = get_options()
    setup_logging(options)

    # Run the server; anything that goes wrong is logged and turned into
    # the EXCEPTION return code instead of a traceback on exit.
    try:
        rc = smbserver(options)
    except Exception as exc:
        log.exception(exc)
        rc = ScriptRC.EXCEPTION

    # Remove the PID file again if one was requested and still exists
    pidfile = options.pidfile
    if pidfile and os.path.isfile(pidfile):
        os.unlink(pidfile)

    log.info("[SMB] Returning %d", rc)
    sys.exit(rc)
455    sys.exit(rc)
456