xref: /curl/tests/smbserver.py (revision 57cc5233)
1#!/usr/bin/env python3
2# -*- coding: utf-8 -*-
3#
4#  Project                     ___| | | |  _ \| |
5#                             / __| | | | |_) | |
6#                            | (__| |_| |  _ <| |___
7#                             \___|\___/|_| \_\_____|
8#
9# Copyright (C) Daniel Stenberg, <daniel@haxx.se>, et al.
10#
11# This software is licensed as described in the file COPYING, which
12# you should have received as part of this distribution. The terms
13# are also available at https://curl.se/docs/copyright.html.
14#
15# You may opt to use, copy, modify, merge, publish, distribute and/or sell
16# copies of the Software, and permit persons to whom the Software is
17# furnished to do so, under the terms of the COPYING file.
18#
19# This software is distributed on an "AS IS" basis, WITHOUT WARRANTY OF ANY
20# KIND, either express or implied.
21#
22# SPDX-License-Identifier: curl
23#
24"""Server for testing SMB."""
25
26from __future__ import (absolute_import, division, print_function,
27                        unicode_literals)
28
29import argparse
30import logging
31import os
32import signal
33import sys
34import tempfile
35import threading
36
37# Import our curl test data helper
38from util import ClosingFileHandler, TestData
39
40if sys.version_info.major >= 3:
41    import configparser
42else:
43    import ConfigParser as configparser
44
45# impacket needs to be installed in the Python environment
46try:
47    import impacket  # noqa: F401
48except ImportError:
49    sys.stderr.write(
50        'Warning: Python package impacket is required for smb testing; '
51        'use pip or your package manager to install it\n')
52    sys.exit(1)
53from impacket import smb as imp_smb
54from impacket import smbserver as imp_smbserver
55from impacket.nt_errors import (STATUS_ACCESS_DENIED, STATUS_NO_SUCH_FILE,
56                                STATUS_SUCCESS)
57
log = logging.getLogger(__name__)
# Placeholder paths configured for the SERVER and TESTS shares in
# smbserver(). TestSmbServer.create_and_x() intercepts opens on shares with
# these paths and serves a freshly generated temporary file instead.
SERVER_MAGIC = "SERVER_MAGIC"
TESTS_MAGIC = "TESTS_MAGIC"
# The only file name known on the SERVER share. Opening it returns
# VERIFIED_RSP with {pid} replaced by this server's PID
# (see TestSmbServer.get_server_path).
VERIFIED_REQ = "verifiedserver"
VERIFIED_RSP = "WE ROOLZ: {pid}\n"
63
64
class ShutdownHandler(threading.Thread):
    """
    Cleanly shut down the SMB server.

    This can only be done from another thread while the server is in
    serve_forever(), so a thread is spawned here that waits for a shutdown
    signal before doing its thing. Use in a with statement around the
    serve_forever() call.

    SIGINT and SIGTERM trigger the shutdown while the with block is active.
    On exit the previously installed handlers for those signals are restored
    and every path listed in ``server.tmpfiles`` is deleted.
    """

    def __init__(self, server):
        super(ShutdownHandler, self).__init__()
        self.server = server
        self.shutdown_event = threading.Event()
        # Signal handlers that were active before __enter__; put back by
        # __exit__.
        self._old_handlers = {}

    def __enter__(self):
        self.start()
        for signum in (signal.SIGINT, signal.SIGTERM):
            self._old_handlers[signum] = signal.signal(signum,
                                                       self._sighandler)
        return self

    def __exit__(self, *_):
        # Call for shutdown just in case it wasn't done already
        self.shutdown_event.set()
        # Wait for thread, and therefore also the server, to finish
        self.join()
        # Restore the signal handlers we replaced. signal.signal() reports
        # None for handlers not installed from Python; fall back to SIG_DFL.
        for signum, handler in self._old_handlers.items():
            if handler is None:
                handler = signal.SIG_DFL
            signal.signal(signum, handler)
        self._old_handlers.clear()
        # Delete any temporary files created by the server during its run.
        # A single failure (e.g. a file still open on Windows) must not stop
        # the remaining files from being removed, so log it and carry on.
        log.info("Deleting %d temporary file(s)", len(self.server.tmpfiles))
        for f in self.server.tmpfiles:
            try:
                os.unlink(f)
            except OSError as e:
                log.warning("Could not delete temporary file %s: %s", f, e)

    def _sighandler(self, _signum, _frame):
        # Wake up the cleanup task
        self.shutdown_event.set()

    def run(self):
        # Wait for shutdown signal
        self.shutdown_event.wait()
        # Notify the server to shut down
        self.server.shutdown()
107
108
def smbserver(options):
    """Start up a TCP SMB server that serves forever.

    Validates --srcdir, writes the PID file if one was requested, configures
    the SERVER and TESTS shares, and then blocks in serve_forever() until
    SIGINT or SIGTERM triggers a clean shutdown via ShutdownHandler.

    Returns ScriptRC.SUCCESS once the server has stopped.
    Raises ScriptError if --srcdir is missing or is not a directory.
    """
    # Validate arguments first, so a PID file is never written for a server
    # that is not going to start.
    if not options.srcdir:
        raise ScriptError("--srcdir is mandatory")
    if not os.path.isdir(options.srcdir):
        raise ScriptError("--srcdir %s is not a directory" % options.srcdir)

    if options.pidfile:
        pid = os.getpid()
        # see tests/server/util.c function write_pidfile
        if os.name == "nt":
            pid += 65536
        with open(options.pidfile, "w") as f:
            f.write(str(pid))

    # Here we write a mini config for the server
    smb_config = configparser.ConfigParser()
    smb_config.add_section("global")
    smb_config.set("global", "server_name", "SERVICE")
    smb_config.set("global", "server_os", "UNIX")
    smb_config.set("global", "server_domain", "WORKGROUP")
    smb_config.set("global", "log_file", "None")
    smb_config.set("global", "credentials_file", "")

    # We need a share which allows us to test that the server is running
    smb_config.add_section("SERVER")
    smb_config.set("SERVER", "comment", "server function")
    smb_config.set("SERVER", "read only", "yes")
    smb_config.set("SERVER", "share type", "0")
    smb_config.set("SERVER", "path", SERVER_MAGIC)

    # Have a share for tests.  These files will be autogenerated from the
    # test input.
    smb_config.add_section("TESTS")
    smb_config.set("TESTS", "comment", "tests")
    smb_config.set("TESTS", "read only", "yes")
    smb_config.set("TESTS", "share type", "0")
    smb_config.set("TESTS", "path", TESTS_MAGIC)

    test_data_dir = os.path.join(options.srcdir, "data")

    smb_server = TestSmbServer((options.host, options.port),
                               config_parser=smb_config,
                               test_data_directory=test_data_dir)
    log.info("[SMB] setting up SMB server on port %s", options.port)
    smb_server.processConfigFile()

    # Start a thread that cleanly shuts down the server on a signal
    with ShutdownHandler(smb_server):
        # This will block until smb_server.shutdown() is called
        smb_server.serve_forever()

    return ScriptRC.SUCCESS
160
161
class TestSmbServer(imp_smbserver.SMBSERVER):
    """
    Test server for SMB which subclasses the impacket SMBSERVER and provides
    test functionality.

    NT_CREATE_ANDX requests on shares whose path is SERVER_MAGIC or
    TESTS_MAGIC are answered from a temporary file generated on the fly. All
    other opens go to impacket's stock handler. Each generated temporary
    file's path is appended to ``self.tmpfiles`` for later cleanup.
    """

    def __init__(self,
                 address,
                 config_parser=None,
                 test_data_directory=None):
        imp_smbserver.SMBSERVER.__init__(self,
                                         address,
                                         config_parser=config_parser)
        # Paths of temporary files handed out to clients; deleting them is
        # left to whoever owns the server.
        self.tmpfiles = []

        # Set up a test data object so we can get test data later.
        self.ctd = TestData(test_data_directory)

        # Override smbComNtCreateAndX so we can pretend to have files which
        # don't exist.
        self.hookSmbCommand(imp_smb.SMB.SMB_COM_NT_CREATE_ANDX,
                            self.create_and_x)

    def create_and_x(self, conn_id, smb_server, smb_command, recv_packet):
        """
        Our version of smbComNtCreateAndX looks for special test files and
        fools the rest of the framework into opening them as if they were
        normal files.

        Returns the ``([response command], None, NT status)`` triple used by
        impacket command hooks. Only FILE_OPEN (read) dispositions are
        accepted on the magic shares.
        """
        conn_data = smb_server.getConnectionData(conn_id)
        # OS descriptor of the generated temporary file, once one exists.
        fid = None

        # Wrap processing in a try block which allows us to throw SmbError
        # to control the flow.
        try:
            ncax_parms = imp_smb.SMBNtCreateAndX_Parameters(
                smb_command["Parameters"])

            path = self.get_share_path(conn_data,
                                       ncax_parms["RootFid"],
                                       recv_packet["Tid"])
            log.info("[SMB] Requested share path: %s", path)

            disposition = ncax_parms["Disposition"]
            log.debug("[SMB] Requested disposition: %s", disposition)

            # Currently we only support reading files.
            if disposition != imp_smb.FILE_OPEN:
                raise SmbError(STATUS_ACCESS_DENIED,
                               "Only support reading files")

            # Check to see if the path we were given is actually a
            # magic path which needs generating on the fly.
            if path not in [SERVER_MAGIC, TESTS_MAGIC]:
                # Pass the command onto the original handler.
                return imp_smbserver.SMBCommands.smbComNtCreateAndX(
                    conn_id, smb_server, smb_command, recv_packet)

            flags2 = recv_packet["Flags2"]
            ncax_data = imp_smb.SMBNtCreateAndX_Data(flags=flags2,
                                                     data=smb_command["Data"])
            requested_file = imp_smbserver.decodeSMBString(
                flags2, ncax_data["FileName"])
            log.debug("[SMB] User requested file '%s'", requested_file)

            if path == SERVER_MAGIC:
                fid, full_path = self.get_server_path(requested_file)
            else:
                assert path == TESTS_MAGIC
                fid, full_path = self.get_test_path(requested_file)

            self.tmpfiles.append(full_path)

            resp_parms = imp_smb.SMBNtCreateAndXResponse_Parameters()
            resp_data = ""

            # Use one more than the highest fid on this connection.
            # dict.keys() cannot be indexed on Python 3, and the most
            # recently inserted fid is not necessarily the largest one.
            opened_files = conn_data["OpenedFiles"]
            fakefid = max(opened_files) + 1 if opened_files else 1
            resp_parms["Fid"] = fakefid
            resp_parms["CreateAction"] = disposition

            # Inspect the file actually being served, not the share's magic
            # placeholder path.
            if os.path.isdir(full_path):
                resp_parms[
                    "FileAttributes"] = imp_smb.SMB_FILE_ATTRIBUTE_DIRECTORY
                resp_parms["IsDirectory"] = 1
            else:
                resp_parms["IsDirectory"] = 0
                resp_parms["FileAttributes"] = ncax_parms["FileAttributes"]

            # Get this file's information
            resp_info, error_code = imp_smbserver.queryPathInformation(
                os.path.dirname(full_path), os.path.basename(full_path),
                level=imp_smb.SMB_QUERY_FILE_ALL_INFO)

            if error_code != STATUS_SUCCESS:
                raise SmbError(error_code, "Failed to query path info")

            resp_parms["CreateTime"] = resp_info["CreationTime"]
            resp_parms["LastAccessTime"] = resp_info["LastAccessTime"]
            resp_parms["LastWriteTime"] = resp_info["LastWriteTime"]
            resp_parms["LastChangeTime"] = resp_info["LastChangeTime"]
            resp_parms["FileAttributes"] = resp_info["ExtFileAttributes"]
            resp_parms["AllocationSize"] = resp_info["AllocationSize"]
            resp_parms["EndOfFile"] = resp_info["EndOfFile"]

            # Store the fid for the connection, with our temporary file's
            # descriptor as its handle. FileName holds the magic share path,
            # which get_share_path() returns when this fid is a RootFid.
            conn_data["OpenedFiles"][fakefid] = {
                "FileHandle": fid,
                "FileName": path,
                "DeleteOnClose": False,
            }

        except SmbError as s:
            log.debug("[SMB] SmbError hit: %s", s)
            error_code = s.error_code
            resp_parms = ""
            resp_data = ""
            # A descriptor we created but never registered would otherwise
            # leak. The file itself is already listed in self.tmpfiles.
            if fid is not None:
                os.close(fid)

        resp_cmd = imp_smb.SMBCommand(imp_smb.SMB.SMB_COM_NT_CREATE_ANDX)
        resp_cmd["Parameters"] = resp_parms
        resp_cmd["Data"] = resp_data
        smb_server.setConnectionData(conn_id, conn_data)

        return [resp_cmd], None, error_code

    def get_share_path(self, conn_data, root_fid, tid):
        """
        Resolve the path an NT_CREATE_ANDX request refers to.

        With a non-zero root_fid, the "FileName" stored for that open fid is
        returned. Otherwise, the configured "path" of the tree connection
        identified by tid is returned.

        Raises SmbError(STATUS_SMB_BAD_TID) for an unknown tid, and
        SmbError(STATUS_ACCESS_DENIED) when the share has no path.
        """
        conn_shares = conn_data["ConnectedShares"]

        if tid not in conn_shares:
            raise SmbError(imp_smbserver.STATUS_SMB_BAD_TID,
                           "TID was invalid")

        if root_fid > 0:
            # If we have a rootFid, the path is relative to that fid
            path = conn_data["OpenedFiles"][root_fid]["FileName"]
            log.debug("RootFid present %s!", path)
            return path

        share = conn_shares[tid]
        if "path" not in share:
            raise SmbError(STATUS_ACCESS_DENIED,
                           "Connection share had no path")
        return share["path"]

    def get_server_path(self, requested_filename):
        """
        Generate a file for the SERVER share.

        Only VERIFIED_REQ is known. Its contents are VERIFIED_RSP with this
        process's PID filled in (plus 65536 on Windows, as in
        tests/server/util.c write_pidfile). Returns the open OS file
        descriptor, rewound to the start, and the temporary file's path.

        Raises SmbError(STATUS_NO_SUCH_FILE) for any other name.
        """
        log.debug("[SMB] Get server path '%s'", requested_filename)

        if requested_filename not in [VERIFIED_REQ]:
            raise SmbError(STATUS_NO_SUCH_FILE, "Couldn't find the file")

        fid, filename = tempfile.mkstemp()
        log.debug("[SMB] Created %s (%d) for storing '%s'",
                  filename, fid, requested_filename)

        # os.write() needs bytes, so the fallback must be bytes too.
        contents = b""

        if requested_filename == VERIFIED_REQ:
            log.debug("[SMB] Verifying server is alive")
            pid = os.getpid()
            # see tests/server/util.c function write_pidfile
            if os.name == "nt":
                pid += 65536
            contents = VERIFIED_RSP.format(pid=pid).encode('utf-8')

        self.write_to_fid(fid, contents)
        return fid, filename

    def write_to_fid(self, fid, contents):
        """
        Write all of ``contents`` (bytes) to descriptor ``fid``, fsync it,
        and rewind it to offset 0 so a subsequent read returns the contents.
        """
        # os.write() may write fewer bytes than requested; loop until done.
        remaining = memoryview(contents)
        while remaining:
            written = os.write(fid, remaining)
            remaining = remaining[written:]
        os.fsync(fid)

        # Rewind the file to the beginning so a read gets us the contents
        os.lseek(fid, 0, os.SEEK_SET)

    def get_test_path(self, requested_filename):
        """
        Generate a file for the TESTS share.

        The contents are self.ctd.get_test_data(requested_filename) encoded
        as UTF-8. Returns the open OS file descriptor, rewound to the start,
        and the temporary file's path.

        Raises SmbError(STATUS_NO_SUCH_FILE) if the data cannot be produced.
        In that case the temporary file is closed and removed here, because
        the caller never learns about it.
        """
        log.info("[SMB] Get reply data from 'test%s'", requested_filename)

        fid, filename = tempfile.mkstemp()
        log.debug("[SMB] Created %s (%d) for storing test '%s'",
                  filename, fid, requested_filename)

        try:
            contents = self.ctd.get_test_data(requested_filename).encode('utf-8')
            self.write_to_fid(fid, contents)
        except Exception:
            log.exception("Failed to make test file")
            self._discard_tmpfile(fid, filename)
            raise SmbError(STATUS_NO_SUCH_FILE, "Failed to make test file")

        return fid, filename

    @staticmethod
    def _discard_tmpfile(fid, filename):
        """Best-effort close and delete of a temporary file not served."""
        try:
            os.close(fid)
        except OSError:
            pass
        try:
            os.unlink(filename)
        except OSError:
            log.warning("[SMB] Could not delete temporary file %s", filename)
363
364
class SmbError(Exception):
    """
    Abort SMB request processing with a specific NT status.

    ``error_code`` is the status that create_and_x() places in its response.
    The message is the exception's text.
    """

    def __init__(self, error_code, error_message):
        Exception.__init__(self, error_message)
        self.error_code = error_code
369
370
class ScriptRC(object):
    """Exit codes this script hands to sys.exit()."""

    # 0: clean run, 1: failure, 2: unexpected exception
    SUCCESS, FAILURE, EXCEPTION = 0, 1, 2
377
378
class ScriptError(Exception):
    """Raised when the script cannot run, e.g. because of bad options."""
381
382
def get_options(argv=None):
    """Parse the command line options.

    argv is the list of arguments to parse, without the program name. When
    it is None, argparse falls back to sys.argv[1:], which is the original
    behaviour.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument("--port", action="store", default=9017,
                        type=int, help="port to listen on")
    parser.add_argument("--host", action="store", default="127.0.0.1",
                        help="host to listen on")
    parser.add_argument("--verbose", action="store", type=int, default=0,
                        help="verbose output")
    parser.add_argument("--pidfile", action="store",
                        help="file name for the PID")
    parser.add_argument("--logfile", action="store",
                        help="file name for the log")
    parser.add_argument("--srcdir", action="store", help="test directory")
    parser.add_argument("--id", action="store", help="server ID")
    # store_true already defaults to False; an explicit 0 made the
    # attribute an int when the flag was absent.
    parser.add_argument("--ipv4", action="store_true",
                        help="IPv4 flag")

    return parser.parse_args(argv)
402
403
def setup_logging(options):
    """Configure the root logger from the parsed command line options.

    A ClosingFileHandler is attached when --logfile is given. A stdout
    handler is attached when no log file was given or when --verbose is set.
    Both handlers pass DEBUG and above; the root logger's own threshold is
    DEBUG in verbose mode and WARNING otherwise.
    """
    root = logging.getLogger()
    fmt = logging.Formatter("%(asctime)s %(levelname)-5.5s %(message)s")

    def _attach(handler):
        handler.setFormatter(fmt)
        handler.setLevel(logging.DEBUG)
        root.addHandler(handler)

    if options.logfile:
        _attach(ClosingFileHandler(options.logfile))

    root.setLevel(logging.DEBUG if options.verbose else logging.WARNING)

    # Without a log file we still need somewhere to log; verbose mode
    # always echoes to stdout as well.
    if options.verbose or not options.logfile:
        _attach(logging.StreamHandler(sys.stdout))
433
434
if __name__ == '__main__':
    # Parse the command line, then configure logging from it.
    options = get_options()
    setup_logging(options)

    # Run the server. Any exception escaping it is logged and mapped to the
    # EXCEPTION return code.
    try:
        rc = smbserver(options)
    except Exception:
        log.exception('Error in SMB server')
        rc = ScriptRC.EXCEPTION

    # Remove the PID file if one was requested and it exists.
    pidfile = options.pidfile
    if pidfile and os.path.isfile(pidfile):
        os.unlink(pidfile)

    log.info("[SMB] Returning %d", rc)
    sys.exit(rc)
454