# Source code for tlsfuzzer.expect

# Author: Hubert Kario, (c) 2015
# Released under Gnu GPL v2.0, see LICENSE file for details

"""Parsing and processing of received TLS messages"""
from __future__ import print_function

import itertools
from functools import partial
import sys
import time

import tlslite.utils.tlshashlib as hashlib
from tlslite.constants import ContentType, HandshakeType, CertificateType,\
        HashAlgorithm, SignatureAlgorithm, ExtensionType,\
        SSL2HandshakeType, CipherSuite, GroupName, AlertDescription, \
        SignatureScheme, TLS_1_3_HRR, HeartbeatMode, \
        TLS_1_1_DOWNGRADE_SENTINEL, TLS_1_2_DOWNGRADE_SENTINEL, \
        HeartbeatMessageType, ClientCertificateType, CertificateStatusType
from tlslite.messages import ServerHello, Certificate, ServerHelloDone,\
        ChangeCipherSpec, Finished, Alert, CertificateRequest, ServerHello2,\
        ServerKeyExchange, ClientHello, ServerFinished, CertificateStatus, \
        CertificateVerify, EncryptedExtensions, NewSessionTicket, Heartbeat,\
        KeyUpdate, HelloRequest, NewSessionTicket1_0
from tlslite.extensions import TLSExtension, ALPNExtension
from tlslite.utils.codec import Parser, Writer
from tlslite.utils.compat import b2a_hex
from tlslite.utils.cryptomath import secureHMAC, derive_secret, \
        HKDF_expand_label
from tlslite.mathtls import RFC7919_GROUPS, FFDHE_PARAMETERS, calc_key
from tlslite.keyexchange import KeyExchange, DHE_RSAKeyExchange, \
        ECDHE_RSAKeyExchange
from tlslite.x509 import X509
from tlslite.x509certchain import X509CertChain
from tlslite.errors import TLSDecryptionFailed
from tlslite.handshakehashes import HandshakeHashes
from tlslite.handshakehelpers import HandshakeHelpers
from .handshake_helpers import calc_pending_states, kex_for_group, \
        curve_name_to_hash_tls13
from .helpers import ECDSA_SIG_TLS1_3_ALL
from .tree import TreeNode

# pylint: disable=import-error,no-name-in-module
# pylint: disable=bad-option-value,deprecated-class
if sys.version_info >= (3, 3):
    from collections.abc import Iterable
else:
    from collections import Iterable
# pylint: enable=bad-option-value,deprecated-class
# pylint: enable=import-error,no-name-in-module


class Expect(TreeNode):
    """Base class for objects handling message readers"""

    def __init__(self, content_type):
        """
        Prepare the class for handling tree graph

        :type content_type: int
        :param content_type: record layer content type of the messages this
            node can handle (compared against ``msg.contentType``)
        """
        super(Expect, self).__init__()
        self.content_type = content_type

    def is_expect(self):
        """Flag to tell that the object is a message processor"""
        return True

    def is_command(self):
        """Flag to tell that the object is not a command"""
        return False

    def is_generator(self):
        """Flag to tell that the object is not a message generator"""
        return False

    def is_match(self, msg):
        """
        Checks if the object can handle message

        Note that the msg is a raw, unparsed message of indicated type that
        requires calling write() to get a raw bytearray() representation of it

        :type msg: tlslite.messages.Message
        :param msg: raw message to check
        :rtype: bool
        :returns: True if the record content type matches the one this node
            was created for
        """
        return msg.contentType == self.content_type

    def process(self, state, msg):
        """
        Process the message and update the state accordingly.

        :type state: tlsfuzzer.runner.ConnectionState
        :param state: current connection state, needs to be updated after
            parsing the message by inheriting classes
        :type msg: tlslite.messages.Message
        :param msg: raw message to parse
        :raises NotImplementedError: always, subclasses must override it
        """
        raise NotImplementedError("Subclasses need to implement this!")
class ExpectMessage(Expect):
    """Shared value-comparison helpers for the TLS message handlers."""

    @staticmethod
    def _cmp_eq(our, recv, field_type=None, f_str=None):
        """
        Verify that the received value equals the expected one.

        Nothing is checked when ``our`` is None. On mismatch an
        AssertionError is raised. Its message is ``f_str`` (or a default
        template) formatted with the expected value as the first argument
        and the received value as the second. When ``field_type`` is given,
        both values are first converted with ``field_type.toStr()``.
        """
        if our is None or our == recv:
            return

        if field_type:
            expected, received = field_type.toStr(our), field_type.toStr(recv)
        else:
            expected, received = our, recv

        template = f_str or "Expected: {0}, received: {1}"
        raise AssertionError(template.format(expected, received))

    @classmethod
    def _cmp_eq_or_in(cls, our, recv, field_type=None, f_str=None):
        """
        Verify that the received value is the expected one or one of them.

        Nothing is checked when ``our`` is None. When ``our`` is a container
        (anything supporting ``in``), ``recv`` must be a member of it.
        Otherwise the check falls back to plain equality (see
        :py:meth:`_cmp_eq`). On failure an AssertionError is raised, with the
        message built from ``f_str`` (expected values first, received value
        second), translating values with ``field_type.toStr()`` when given.
        """
        if our is None:
            return

        try:
            if recv in our:
                return
        except TypeError:
            # ``our`` is not a container, compare it directly
            return cls._cmp_eq(our, recv, field_type, f_str)

        if field_type:
            names = ", ".join(field_type.toStr(i) for i in our)
            expected = "({0})".format(names)
            received = field_type.toStr(recv)
        else:
            expected, received = our, recv

        template = f_str or "Received value ({1}) not in expected list: {0}"
        raise AssertionError(template.format(expected, received))

    @staticmethod
    def _cmp_eq_list(our, recv, field_type=None, f_str=None):
        """
        Verify that the received list of values equals the expected list.

        Nothing is checked when ``our`` is None. On mismatch an
        AssertionError is raised. With ``field_type`` every item is
        translated using ``field_type.toStr()`` and the lists are rendered
        as ``(a, b, ...)``. Without it, ``repr()`` of both lists is used.
        """
        if our is None or our == recv:
            return

        if field_type:
            def _render(values):
                return "({0})".format(
                    ", ".join(field_type.toStr(i) for i in values))
            expected, received = _render(our), _render(recv)
        else:
            expected, received = repr(our), repr(recv)

        template = f_str or "Expected: {0}, received: {1}"
        raise AssertionError(template.format(expected, received))
class ExpectHandshake(ExpectMessage):
    """Common methods for handling TLS Handshake protocol messages"""

    def __init__(self, content_type, handshake_type):
        """
        Remember which record and handshake message type to expect

        :type content_type: int
        :param content_type: record layer content type of the message
        :type handshake_type: int
        :param handshake_type: value of the first byte of the handshake
            message (its type)
        """
        super(ExpectHandshake, self).__init__(content_type)
        self.handshake_type = handshake_type

    def is_match(self, msg):
        """Check if message is a given type of handshake protocol message"""
        if not super(ExpectHandshake, self).is_match(msg):
            return False
        payload = msg.write()
        if not payload:
            # an empty message has no handshake type byte to inspect
            return False
        return Parser(payload).get(1) == self.handshake_type

    def process(self, state, msg):
        """Must be overridden by the specific handshake message handlers."""
        raise NotImplementedError("Subclass need to implement this!")
def srv_ext_handler_ems(state, extension):
    """
    Process Extended Master Secret extension from server.

    The extension must carry no payload. When it is empty the connection
    state is flagged as using the extended master secret.
    """
    if not extension.extData:
        state.extended_master_secret = True
        return
    raise AssertionError("Malformed EMS extension, data in payload")
def srv_ext_handler_etm(state, extension):
    """
    Process Encrypt then MAC extension from server.

    The extension must carry no payload. When it is empty the connection
    state is flagged as using encrypt-then-MAC.
    """
    if not extension.extData:
        state.encrypt_then_mac = True
        return
    raise AssertionError("Malformed EtM extension, data in payload")
def srv_ext_handler_sni(state, extension):
    """
    Process the server_name extension from server.

    A server_name extension in a server reply must be empty.
    """
    del state  # unused, kept for compatibility with the handler signature
    if not extension.extData:
        return
    raise AssertionError("Malformed SNI extenion, data in payload")
def srv_ext_handler_renego(state, extension):
    """
    Process the renegotiation_info from server.

    The renegotiated_connection field must equal the concatenation of the
    client and server verify_data values stored in ``state.key``.
    """
    received = extension.renegotiated_connection
    expected = (state.key['client_verify_data'] +
                state.key['server_verify_data'])
    if received != expected:
        raise AssertionError("Invalid data in renegotiation_info")
def srv_ext_handler_alpn(state, extension):
    """
    Process the ALPN extension from server.

    Checks that the server selected exactly one protocol and that this
    protocol was advertised in the last sent ClientHello.

    :raises AssertionError: when the ClientHello did not include the ALPN
        extension, when the server extension is malformed, or when the
        server picked a protocol that was not advertised
    """
    cln_hello = state.get_last_message_of_type(ClientHello)
    cln_ext = cln_hello.getExtension(ExtensionType.alpn)
    if cln_ext is None:
        # without this check the .extData access below fails with an
        # AttributeError instead of reporting the protocol violation
        raise AssertionError("Server sent ALPN extension but we did not "
                             "advertise any protocols")
    # the sent extension might have been provided with explicit encoding
    # so re-parse its raw payload to get the advertised protocol names
    cln_ext = ALPNExtension().parse(Parser(cln_ext.extData))

    if not extension.protocol_names or len(extension.protocol_names) != 1:
        raise AssertionError("Malformed ALPN extension")
    if extension.protocol_names[0] not in cln_ext.protocol_names:
        raise AssertionError("Server selected ALPN protocol we did not "
                             "advertise")
def srv_ext_handler_ec_point(state, extension):
    """
    Process the ec_point_formats extension from server.

    The list of formats must be present and non-empty.
    """
    del state  # unused, kept for compatibility with the handler signature
    if extension.formats:
        return
    raise AssertionError("Malformed ec_point_formats extension")
def srv_ext_handler_npn(state, extension):
    """
    Process the NPN extension from server.

    The server must list at least one protocol.
    """
    del state  # unused, kept for compatibility with the handler signature
    if extension.protocols:
        return
    raise AssertionError("Malformed NPN extension")
def srv_ext_handler_session_ticket(state, extension):
    """
    Process the session_ticket extension from server.

    In a server reply the ticket field must be empty.
    """
    del state  # unused, kept for compatibility with the handler signature
    if extension.ticket == b"":
        return
    raise AssertionError("Malformed session_ticket extension")
def srv_ext_handler_key_share(state, extension):
    """
    Process the key_share extension from server.

    Finds the client key share for the group selected by the server, stores
    the server's public value in ``state.key`` and computes the shared
    secret from the client's private value.

    :raises AssertionError: when the server selected a group the client
        sent no key share for
    :raises ValueError: when the matching client key share has no private
        value
    """
    cln_hello = state.get_last_message_of_type(ClientHello)
    cln_ext = cln_hello.getExtension(ExtensionType.key_share)

    group_id = extension.server_share.group
    cl_share = None
    for share in cln_ext.client_shares:
        if share.group == group_id:
            cl_share = share
            break
    if cl_share is None:
        raise AssertionError("Server selected group we didn't advertise: {0}"
                             .format(GroupName.toStr(group_id)))

    kex = kex_for_group(group_id, state.version)
    server_public = extension.server_share.key_exchange
    state.key['ServerHello.extensions.key_share.key_exchange'] = \
        server_public

    if not cl_share.private:
        raise ValueError("private value for key share of group {0} missing"
                         .format(GroupName.toStr(group_id)))

    state.key['DH shared secret'] = kex.calc_shared_key(cl_share.private,
                                                        server_public)
def hrr_ext_handler_key_share(state, extension):
    """
    Process the key_share extension in HRR message.

    The group requested by the server must be one of the groups listed in
    the supported_groups extension of the last sent ClientHello.
    """
    cln_hello = state.get_last_message_of_type(ClientHello)
    supported = cln_hello.getExtension(ExtensionType.supported_groups)

    requested = extension.selected_group
    if requested in supported.groups:
        return
    raise AssertionError("Server selected group we didn't advertise: {0}"
                         .format(GroupName.toStr(requested)))
def srv_ext_handler_supp_vers(state, extension):
    """
    Process the supported_versions from server.

    The selected version must be one of those advertised in the last sent
    ClientHello. When it is, it becomes the connection's ``state.version``.
    """
    cln_hello = state.get_last_message_of_type(ClientHello)
    advertised = cln_hello.getExtension(ExtensionType.supported_versions)

    selected = extension.version
    if selected not in advertised.versions:
        raise AssertionError("Server selected version we didn't advertise: {0}"
                             .format(selected))
    state.version = selected
def srv_ext_handler_supp_groups(state, extension):
    """
    Process the supported_groups from server.

    The server must list at least one group.
    """
    del state  # unused, kept for compatibility with the handler signature
    if extension.groups:
        return
    raise AssertionError("Server did not send any supported_groups")
def srv_ext_handler_status_request(state, extension):
    """
    Process the status_request extension from server.

    TLS 1.2 ServerHello specific, in TLS 1.3 the extension resides in
    Certificate message. The extension must be completely empty: no status
    type, no responder IDs and no request extensions.
    """
    del state  # unused, kept for compatibility with the handler signature
    is_empty = (extension.status_type is None and
                extension.responder_id_list == [] and
                extension.request_extensions == bytearray())
    if not is_empty:
        raise AssertionError("Server did send non empty status_request "
                             "extension")
def srv_ext_handler_heartbeat(state, extension):
    """
    Process the heartbeat extension from server.

    The mode must be set and must be either PEER_ALLOWED_TO_SEND or
    PEER_NOT_ALLOWED_TO_SEND.
    """
    del state  # unused, kept for compatibility with the handler signature
    mode = extension.mode
    if not mode:
        raise AssertionError("Empty mode in heartbeat extension.")
    allowed_modes = (HeartbeatMode.PEER_ALLOWED_TO_SEND,
                     HeartbeatMode.PEER_NOT_ALLOWED_TO_SEND)
    if mode not in allowed_modes:
        raise AssertionError("Invalid mode in heartbeat extension.")
def _srv_ext_handler_psk(state, extension, psk_configs):
    """Process the pre_shared_key extension from server.

    Since it needs the psk_configurations, it can't do it automatically
    so it shouldn't be part of _srv_ext_handler.

    If the selected identity is the ticket from the most recently received
    NewSessionTicket, the PSK is derived from the resumption master secret.
    Otherwise the secret is looked up in ``psk_configs`` by identity.

    :param psk_configs: iterable of tuples whose first element is the PSK
        identity and the second the PSK secret
    :raises AssertionError: when the server selected an index outside of
        the list of identities we sent
    :raises ValueError: when the selected external identity is not present
        in ``psk_configs``
    """
    cln_hello = state.get_last_message_of_type(ClientHello)
    cln_ext = cln_hello.getExtension(ExtensionType.pre_shared_key)

    # the selection is 0-based
    if extension.selected >= len(cln_ext.identities):
        raise AssertionError("Server selected PSK we didn't send")

    ident = cln_ext.identities[extension.selected].identity

    if state.session_tickets:
        nst = state.session_tickets[-1]
        if nst.ticket == ident:
            state.key['PSK secret'] = HandshakeHelpers.calc_res_binder_psk(
                cln_ext.identities[extension.selected],
                state.key['resumption master secret'],
                [nst])
            return

    secret = next((i[1] for i in psk_configs if i[0] == ident), None)
    # compare against None: a configured (even if empty) secret means the
    # identity is present, only a failed lookup means it is missing
    if secret is None:
        raise ValueError("psk_configs are missing identity")

    state.key['PSK secret'] = secret
def gen_srv_ext_handler_psk(psk_configs=tuple()):
    """
    Creates a handler for pre_shared_key extension from the server.

    :param psk_configs: PSK configurations (identity, secret, ...) that the
        returned handler uses to look up the selected external PSK
    :returns: callable accepting ``(state, extension)``
    """
    handler = partial(_srv_ext_handler_psk, psk_configs=psk_configs)
    return handler
def _srv_ext_handler_record_limit(state, extension, size=None):
    """
    Process record_size_limit extension from server.

    Validates the value sent by the server and applies the negotiated
    limits. In TLS 1.2 and earlier they are only stored in ``state``, to be
    applied on server CCS. In TLS 1.3 they are set on the record layer
    right away.

    :param state: current connection state
    :param extension: record_size_limit extension received from server
    :param int size: expected value of the limit, None to accept any valid
        value
    :raises AssertionError: when the value is missing, outside of the range
        allowed for the negotiated protocol version, or different from
        ``size``
    """
    cln_hello = state.get_last_message_of_type(ClientHello)
    cln_ext = cln_hello.getExtension(ExtensionType.record_size_limit)

    limit = extension.record_size_limit
    # explicit checks instead of bare ``assert`` statements so that they
    # are not stripped with ``python -O`` and the failure says what's wrong
    if limit is None:
        raise AssertionError("Malformed record_size_limit extension, "
                             "missing value")
    # in TLS 1.3 the limit also covers the content type byte, so it can be
    # one larger than the maximum plaintext size
    max_limit = 2**14 + int(state.version > (3, 3))
    if not 64 <= limit <= max_limit:
        raise AssertionError("Server sent record_size_limit outside of "
                             "allowed range (64 to {0}): {1}"
                             .format(max_limit, limit))

    if size and limit != size:
        raise AssertionError("Server sent unexpected size in extension, "
                             "expected size: {0}, received size: {1}"
                             .format(size, limit))

    if state.version <= (3, 3):
        # in TLS 1.2 and earlier we need to delay that to processing of
        # server CCS
        state._peer_record_size_limit = limit
        state._our_record_size_limit = min(2**14, cln_ext.record_size_limit)
    else:
        # in TLS 1.3 we need to implement it right away (as the extension
        # applies only to encrypted messages)
        # the RecordLayer expects value that excludes content type
        state.msg_sock.recv_record_limit = min(
            2**14, cln_ext.record_size_limit-1)
        # this is just hint for padding callback
        state.msg_sock.send_record_limit = min(2**14, limit-1)
        # this guides fragmentation
        state.msg_sock.recordSize = state.msg_sock.send_record_limit
def gen_srv_ext_handler_record_limit(size=None):
    """
    Create a handler for record_size_limit_extension from the server.

    If the extension is actually negotiated, it overrides any
    `~SetMaxRecordSize()` that precedes EncryptedExtensions (TLS 1.3) or
    ChangeCipherSpec (TLS 1.2 and earlier).

    :param int size: value the server is expected to send, None to accept
        any valid value
    :returns: callable accepting ``(state, extension)``
    """
    handler = partial(_srv_ext_handler_record_limit, size=size)
    return handler
def clnt_ext_handler_status_request(state, extension):
    """
    Check status_request extension from initiating side.

    To be used in ClientHello and CertificateRequest. The status type must
    be OCSP and both the responder ID list and request extensions must be
    present (possibly empty).
    """
    del state  # unused, kept for compatibility with the handler signature
    status_type = extension.status_type
    if status_type != CertificateStatusType.ocsp:
        raise AssertionError(
            "Unexpected status_type in status_request extension: {0}"
            .format(CertificateStatusType.toStr(status_type)))
    fields_missing = (extension.responder_id_list is None or
                      extension.request_extensions is None)
    if fields_missing:
        raise AssertionError(
            "Malformed status_request extension")
def clnt_ext_handler_sig_algs(state, extension):
    """
    Check signature_algorithms or signature_algorithms_cert extension.

    To be used in ClientHello and CertificateRequest. The list of signature
    algorithms must be non-empty.
    """
    del state  # unused, kept for API compatibility
    if extension.sigalgs:
        return
    raise AssertionError(
        "Empty or malformed {0} extension"
        .format(ExtensionType.toStr(extension.extType)))
def hrr_ext_handler_cookie(state, extension):
    """
    Process the cookie extension in HRR message.

    The cookie must not be empty (RFC 8446 defines it as
    ``opaque cookie<1..2^16-1>``).

    :raises AssertionError: when the cookie is missing or empty
    """
    del state  # unused, kept for compatibility with the handler signature
    if not extension.cookie:
        raise AssertionError("Malformed cookie extension, no cookie")


# automatic handlers for extensions in ServerHello, selected by extension
# type when the expected value for the extension type is None
_srv_ext_handler = \
    {ExtensionType.extended_master_secret: srv_ext_handler_ems,
     ExtensionType.encrypt_then_mac: srv_ext_handler_etm,
     ExtensionType.server_name: srv_ext_handler_sni,
     ExtensionType.renegotiation_info: srv_ext_handler_renego,
     ExtensionType.alpn: srv_ext_handler_alpn,
     ExtensionType.session_ticket: srv_ext_handler_session_ticket,
     ExtensionType.ec_point_formats: srv_ext_handler_ec_point,
     ExtensionType.supports_npn: srv_ext_handler_npn,
     ExtensionType.key_share: srv_ext_handler_key_share,
     ExtensionType.supported_versions: srv_ext_handler_supp_vers,
     ExtensionType.heartbeat: srv_ext_handler_heartbeat,
     ExtensionType.record_size_limit: _srv_ext_handler_record_limit,
     ExtensionType.status_request: srv_ext_handler_status_request}

# automatic handlers for extensions in HelloRetryRequest
# (hrr_ext_handler_cookie has to be defined before this point, otherwise
# the module fails to import with a NameError)
_HRR_EXT_HANDLER = \
    {ExtensionType.key_share: hrr_ext_handler_key_share,
     ExtensionType.cookie: hrr_ext_handler_cookie}

# automatic handlers for extensions in (presumably) TLS 1.3
# EncryptedExtensions
_EE_EXT_HANDLER = \
    {ExtensionType.server_name: srv_ext_handler_sni,
     ExtensionType.alpn: srv_ext_handler_alpn,
     ExtensionType.supported_groups: srv_ext_handler_supp_groups,
     ExtensionType.heartbeat: srv_ext_handler_heartbeat,
     ExtensionType.record_size_limit: _srv_ext_handler_record_limit}

# automatic handlers for extensions in CertificateRequest
_CR_EXT_HANDLER = \
    {ExtensionType.status_request: clnt_ext_handler_status_request,
     ExtensionType.signature_algorithms: clnt_ext_handler_sig_algs,
     ExtensionType.signature_algorithms_cert: clnt_ext_handler_sig_algs}
class _ExpectExtensionsMessage(ExpectHandshake):
    """
    Common methods of messages that have a list of extensions.

    Used in ServerHello, EncryptedExtensions and CertificateRequest
    (in TLS 1.3)
    """

    def __init__(self, content_type, msg_type, extensions):
        """
        :param int content_type: record layer content type of the message
        :param int msg_type: handshake type of the message
        :param dict extensions: expected extensions keyed by extension type,
            None to skip the exact-match check
        """
        super(_ExpectExtensionsMessage, self).__init__(
            content_type, msg_type)
        self.extensions = extensions

    def _compare_extensions(self, message):
        """
        Verify that server provided extensions match exactly expected list.

        Skipped when ``self.extensions`` is None. A non-empty expectation
        with no received extensions, a missing extension, or an extra
        extension all raise AssertionError.
        """
        wanted = self.extensions
        received_list = message.extensions

        if wanted and not received_list:
            raise AssertionError("Server did not send any extensions")

        # nothing to compare unless both sides are present
        if wanted is None or not received_list:
            return

        expected = set(wanted.keys())
        got = set(ext.extType for ext in received_list)
        if got == expected:
            return

        missing = expected - got
        if missing:
            raise AssertionError("Server did not send extension(s): "
                                 "{0}".format(
                                     ", ".join((ExtensionType.toStr(i)
                                                for i in missing))))

        unexpected = got - expected
        # the sets differ and nothing is missing, so something must be extra
        assert unexpected
        raise AssertionError("Server sent unexpected extension(s):"
                             " {0}".format(
                                 ", ".join(ExtensionType.toStr(i)
                                           for i in unexpected)))
class ExpectServerHello(_ExpectExtensionsMessage):
    """
    Parsing TLS Handshake protocol Server Hello messages.

    Processing of the ServerHello message updates the record layer to the
    version advertised by the server. Use
    :py:class:`~tlsfuzzer.messages.SetRecordVersion` to change it earlier to
    send records with different versions.

    .. note::
        Receiving of the ServerHello in TLS 1.3 influences record layer
        encryption. After the message is received, the
        ``client_handshake_traffic_secret`` and
        ``server_handshake_traffic_secret`` is derived and record layer is
        configured to expect encrypted records on the *receiving* side.

    :ivar str ~.description: identifier to print when processing of the node
        fails
    """

    def __init__(self, extensions=None, version=None, resume=False,
                 cipher=None, server_max_protocol=None, force_resume=False,
                 description=None):
        """
        Initialize the object

        :param dict extensions: extension objects to match the server sent
            extensions or callbacks to process and verify them.
            None means use automatic handlers that will verify the response
            against the extensions sent in ClientHello. Empty dict means
            that the server is expected to send no extensions.
            Order does not matter, but all extensions present and only
            extensions present in the list must be sent by server.
            None as the value of the relevant extension type can be used
            to select autohandler for a given extension type.

        :param tuple version: the literal version in the Server Hello
            message (needs to be (3, 3) for TLS 1.3, use extensions to
            expect TLS 1.3 negotiation)

        :param tuple server_max_protocol: the highest protocol version
            supported by server. Used for testing downgrade signaling of
            servers.

        :type cipher: int or set-like
        :param int cipher: the id of the cipher that is expected to be
            negotiated by server. Can also be a list or set (needs to
            support ``in``) for a set of allowed ciphers. None (the
            default) means any cipher sent in ClientHello can be selected
            by server (NOTE(review): this class does not itself filter out
            SCSV or GREASE values).

        :type resume: boolean
        :param resume: whether the session id should match the one from
            current state - IOW, if the server hello should belong to a
            resumed session. TLS 1.2 and earlier only.
            In TLS 1.3 resumption is handled by providing handler for
            ``pre_shared_key`` extension.

        :param boolean force_resume: assume that the session is getting
            resumed, even if the sessionID is empty. Applicable to TLS 1.2
            and earlier only when using session tickets and not sending a
            sessionID.

        :param str description: identifier to print when processing of the
            node fails
        """
        super(ExpectServerHello, self).__init__(ContentType.handshake,
                                                HandshakeType.server_hello,
                                                extensions)
        self.cipher = cipher
        self.version = version
        self.resume = resume
        self.srv_max_prot = server_max_protocol
        self.force_resume = force_resume
        self.description = description

    def __str__(self):
        """Return human readable representation of the object."""
        if self.description:
            return "ExpectServerHello(description={0!r})"\
                .format(self.description)
        return "ExpectServerHello()"

    @staticmethod
    def _get_autohandler(ext_id):
        """
        Return the automatic handler for the given extension type.

        :raises AssertionError: when there is no automatic handler for it
        """
        try:
            return _srv_ext_handler[ext_id]
        except KeyError:
            raise AssertionError("No autohandler for "
                                 "{0}"
                                 .format(ExtensionType.toStr(ext_id)))

    def _process_extensions(self, state, cln_hello, srv_hello):
        """
        Check if extensions are correct.

        Every extension sent by server must be allowed in the message (for
        TLS 1.3), must have been advertised by the client (with the
        exceptions below) and must pass the configured or automatic
        handler: a callable is called with ``(state, extension)``, a
        TLSExtension instance must compare equal to the received one.
        """
        # extensions allowed in TLS 1.3 ServerHello and HelloRetryRequest
        # messages (as some need to be echoed by server in EncryptedExtensions
        # and some in Certificate)
        sh_supported = [ExtensionType.pre_shared_key,
                        ExtensionType.supported_versions,
                        ExtensionType.key_share]
        hrr_supported = [ExtensionType.cookie,
                         ExtensionType.supported_versions,
                         ExtensionType.key_share]
        for ext in srv_hello.extensions:
            ext_id = ext.extType
            if state.version > (3, 3) and \
                    ((srv_hello.random != TLS_1_3_HRR and
                      ext_id not in sh_supported) or
                     (srv_hello.random == TLS_1_3_HRR and
                      ext_id not in hrr_supported)):
                raise AssertionError("Server sent unallowed "
                                     "extension of type {0}"
                                     .format(ExtensionType.toStr(ext_id)))
            # in TLS 1.2 generally the server can reply to any client sent
            # extension, and all of them end in ClientHello
            cl_ext = cln_hello.getExtension(ext_id)
            # the renegotiation SCSV in cipher suites stands in for the
            # client's renegotiation_info extension
            if ext_id == ExtensionType.renegotiation_info and \
                    CipherSuite.TLS_EMPTY_RENEGOTIATION_INFO_SCSV \
                    in cln_hello.cipher_suites:
                cl_ext = True
            # HRR may carry a cookie even if the client didn't send one
            if isinstance(self, ExpectHelloRetryRequest) and \
                    ext_id == ExtensionType.cookie:
                cl_ext = True
            if cl_ext is None:
                raise AssertionError("Server sent unadvertised "
                                     "extension of type {0}"
                                     .format(ExtensionType.toStr(ext_id)))
            handler = None
            if self.extensions:
                # _compare_extensions() already verified the key sets match
                handler = self.extensions[ext_id]
            # use automatic handlers for some extensions
            if handler is None:
                handler = self._get_autohandler(ext_id)
            if callable(handler):
                handler(state, ext)
            elif isinstance(handler, TLSExtension):
                if not handler == ext:
                    raise AssertionError("Expected extension not "
                                         "matched for type {0}, "
                                         "received: {1}"
                                         .format(ExtensionType.toStr(ext_id),
                                                 ext))
            else:
                raise ValueError("Bad extension handler for id {0}"
                                 .format(ExtensionType.toStr(ext_id)))

    @staticmethod
    def _extract_version(msg):
        """
        Extract the real version from the message if TLS 1.3 is in use.

        :raises ValueError: when legacy_version is higher than (3, 3)
        :returns: version from supported_versions when present (and
            legacy_version is (3, 3)), legacy_version otherwise
        """
        ext = msg.getExtension(ExtensionType.supported_versions)

        # RFC 8446 "legacy_version field MUST be set to 0x0303"
        if msg.server_version > (3, 3):
            raise ValueError("Server sent invalid version in legacy_version "
                             "field")

        if ext and msg.server_version == (3, 3):
            return ext.version
        return msg.server_version

    def process(self, state, msg):
        """
        Process the message and update state accordingly

        :type state: ConnectionState
        :param state: overall state of TLS connection

        :type msg: Message
        :param msg: TLS Message read from socket
        :returns: the parsed ServerHello
        """
        assert msg.contentType == ContentType.handshake

        parser = Parser(msg.write())
        hs_type = parser.get(1)
        assert hs_type == HandshakeType.server_hello

        srv_hello = ServerHello()
        srv_hello.parse(parser)

        # extract important info
        state.server_random = srv_hello.random

        cln_hello = state.get_last_message_of_type(ClientHello)

        # check for session_id based session resumption
        if self.resume:
            assert state.session_id == srv_hello.session_id
        if self.force_resume or ((state.session_id == srv_hello.session_id or
                                  cln_hello.session_id == srv_hello.session_id)
                                 and srv_hello.session_id != bytearray(0)
                                 and self._extract_version(srv_hello)
                                 < (3, 4)):
            # TLS 1.2 resumption, TLS 1.3 is based on PSKs
            state.resuming = True
            assert state.cipher == srv_hello.cipher_suite
            assert state.version == self._extract_version(srv_hello)
        state.session_id = srv_hello.session_id

        self._cmp_eq(self.version, srv_hello.server_version,
                     f_str="Server selected unexpected protocol version. "
                           "Expected: {0}, received: {1}.")

        self._cmp_eq_or_in(
            self.cipher, srv_hello.cipher_suite,
            f_str="Server selected unexpected ciphersuite. "
                  "Expected: {0}, received: {1}.")

        # check if server sent cipher matches what we advertised in CH
        if srv_hello.cipher_suite not in cln_hello.cipher_suites:
            cipher = srv_hello.cipher_suite
            if cipher in CipherSuite.ietfNames:
                name = "{0} ({1:#06x})".format(CipherSuite.ietfNames[cipher],
                                               cipher)
            else:
                name = "{0:#06x}".format(cipher)
            raise AssertionError("Server responded with cipher we did"
                                 " not advertise: {0}".format(name))

        state.cipher = srv_hello.cipher_suite
        state.version = self._extract_version(srv_hello)

        # update the state of connection
        state.msg_sock.version = state.version
        state.msg_sock.tls13record = state.version > (3, 3)

        # must run before srv_hello is added to handshake_messages, so the
        # last ServerHello found there is a possible HelloRetryRequest
        self._check_against_hrr(state, srv_hello)

        state.handshake_messages.append(srv_hello)
        state.handshake_hashes.update(msg.write())

        # Reset value of the session-wide settings
        state.extended_master_secret = False
        state.encrypt_then_mac = False

        self._check_downgrade_protection(srv_hello)

        self._compare_extensions(srv_hello)
        if srv_hello.extensions:
            self._process_extensions(state, cln_hello, srv_hello)

        if state.version > (3, 3):
            self._setup_tls13_handshake_keys(state)
        return srv_hello

    @staticmethod
    def _check_against_hrr(state, srv_hello):
        """
        Verify ServerHello consistency with a preceding HelloRetryRequest.

        Applies to TLS 1.3 only. When the last received ServerHello was an
        HRR, the cipher suite and selected version must be the same.
        """
        if state.version < (3, 4):
            return
        hrr = state.get_last_message_of_type(ServerHello)
        if not hrr or hrr.random != TLS_1_3_HRR:
            # not an HRR, so HRR tests don't apply to it
            return

        if hrr.cipher_suite != srv_hello.cipher_suite:
            raise AssertionError("Server picked different cipher suite than "
                                 "it advertised in HelloRetryRequest")
        hrr_version = hrr.getExtension(ExtensionType.supported_versions)
        sh_version = srv_hello.getExtension(ExtensionType.supported_versions)
        if hrr_version.version != sh_version.version:
            raise AssertionError("Server picked different protocol version "
                                 "than it advertised in HelloRetryRequest")

    def _setup_tls13_handshake_keys(self, state):
        """
        Set up the encryption keys for the TLS 1.3 handshake.

        Derives the early secret (from 'PSK secret', zeros when unset), the
        handshake secret (from 'DH shared secret', zeros when unset) and
        both handshake traffic secrets. Then it configures the record layer
        and switches only the *read* state to the new keys.
        """
        del self  # instance state is not used
        prf_name = state.prf_name
        prf_size = state.prf_size

        # Derive PSK secret
        psk = state.key.setdefault('PSK secret', bytearray(prf_size))

        # Derive TLS 1.3 early secret
        secret = bytearray(prf_size)
        secret = secureHMAC(secret, psk, prf_name)
        state.key['early secret'] = secret

        # Derive TLS 1.3 handshake secret
        secret = derive_secret(secret, b'derived', None, prf_name)
        dh_secret = state.key.setdefault('DH shared secret',
                                         bytearray(prf_size))
        secret = secureHMAC(secret, dh_secret, prf_name)
        state.key['handshake secret'] = secret

        # Derive TLS 1.3 traffic secrets
        s_traffic_secret = derive_secret(secret, b's hs traffic',
                                         state.handshake_hashes, prf_name)
        state.key['server handshake traffic secret'] = s_traffic_secret
        c_traffic_secret = derive_secret(secret, b'c hs traffic',
                                         state.handshake_hashes, prf_name)
        state.key['client handshake traffic secret'] = c_traffic_secret

        state.msg_sock.calcTLS1_3PendingState(
            state.cipher, c_traffic_secret, s_traffic_secret, None)

        state.msg_sock.changeReadState()

    def _check_downgrade_protection(self, srv_hello):
        """
        Verify that server provided downgrade protection as specified in
        RFC 8446, Section 4.1.3

        :raises AssertionError: when a sentinel is present in the last 8
            bytes of ServerHello.random but must not be, or is expected
            (based on ``server_max_protocol``) but missing
        """
        # even if we don't know which version server supports, some values
        # are obviously incorrect:
        if (self._extract_version(srv_hello) > (3, 3) and
                srv_hello.random[24:] == TLS_1_2_DOWNGRADE_SENTINEL) or \
                (self._extract_version(srv_hello) > (3, 2) and
                 srv_hello.random[24:] == TLS_1_1_DOWNGRADE_SENTINEL):
            raise AssertionError(
                "Server set downgrade protection sentinel but shouldn't "
                "have done that")

        # as we're doing both TLS 1.2 tests and TLS 1.3 tests with `scripts/`
        # we don't know when setting the sentinel is expected and when
        # it is not as the negotiation might have ended up with TLS 1.2
        # because that was the highest version we advertised
        if self.srv_max_prot is None:
            return

        downgrade_value = None
        if self.srv_max_prot > (3, 3) \
                and self._extract_version(srv_hello) == (3, 3):
            downgrade_value = TLS_1_2_DOWNGRADE_SENTINEL
        elif self.srv_max_prot > (3, 2) \
                and self._extract_version(srv_hello) < (3, 3):
            downgrade_value = TLS_1_1_DOWNGRADE_SENTINEL
        else:
            # no downgrade happened, so no sentinel may be present
            if srv_hello.random[24:] == TLS_1_1_DOWNGRADE_SENTINEL or \
                    srv_hello.random[24:] == TLS_1_2_DOWNGRADE_SENTINEL:
                raise AssertionError(
                    "Server set downgrade protection sentinel but shouldn't "
                    "have done that")

        if downgrade_value is not None:
            if srv_hello.random[24:] != downgrade_value:
                raise AssertionError(
                    "Server failed to set downgrade protection sentinel in "
                    "ServerHello.random value")
class ExpectHelloRetryRequest(ExpectServerHello):
    """
    Processing of the TLS 1.3 HelloRetryRequest message.

    The message is parsed like a ServerHello. Instead of deriving handshake
    keys, the handshake transcript is restarted as required by RFC 8446,
    Section 4.4.1.
    """

    def __init__(self, extensions=None, version=None, cipher=None,
                 description=None):
        """
        Initialize the object

        :param dict extensions: expected extensions or handlers, with the
            same semantics as in :py:class:`ExpectServerHello`
        :param tuple version: the literal legacy_version expected in the
            message
        :type cipher: int or set-like
        :param cipher: the id of the cipher (or a set of ids) the server is
            expected to select, None for any advertised cipher
        :param str description: identifier to print when processing of the
            node fails
        """
        # ``cipher`` must be passed by keyword: the third positional
        # parameter of ExpectServerHello is ``resume``, so passing it
        # positionally silently disabled the cipher check
        super(ExpectHelloRetryRequest, self).__init__(
            extensions, version, cipher=cipher, description=description)
        self._ch_hh = None
        self._msg = None

    def process(self, state, msg):
        """
        Process the HelloRetryRequest and update state accordingly

        :type state: ConnectionState
        :param state: overall state of TLS connection
        :type msg: Message
        :param msg: TLS Message read from socket
        :returns: the parsed message
        """
        # transcript up to (and excluding) the HRR, needed to build the
        # message_hash construct
        self._ch_hh = state.handshake_hashes.copy()
        self._msg = msg
        hrr = super(ExpectHelloRetryRequest, self).process(state, msg)
        assert hrr.random == TLS_1_3_HRR
        return hrr

    @staticmethod
    def _get_autohandler(ext_id):
        """
        Return automatic handler for an HRR extension.

        HRR-specific handlers take precedence over the ServerHello ones.

        :raises AssertionError: when there is no automatic handler
        """
        try:
            return _HRR_EXT_HANDLER[ext_id]
        except KeyError:
            try:
                return _srv_ext_handler[ext_id]
            except KeyError:
                raise AssertionError("No autohandler for {0}".format(
                    ExtensionType.toStr(ext_id)))

    def _setup_tls13_handshake_keys(self, state):
        """
        Prepare handshake hashes for the HRR handling

        No keys are derived here. The transcript is replaced with a new one
        that starts with a synthetic ``message_hash`` handshake message
        carrying the hash of the previous transcript, followed by the
        HelloRetryRequest itself.
        """
        prf_name = state.prf_name
        ch_hash = self._ch_hh.digest(prf_name)
        new_hh = HandshakeHashes()
        writer = Writer()
        writer.add(HandshakeType.message_hash, 1)
        writer.addVarSeq(ch_hash, 1, 3)
        new_hh.update(writer.bytes)

        new_hh.update(self._msg.write())

        state.handshake_hashes = new_hh
class ExpectServerHello2(ExpectHandshake):
    """Processing of SSLv2 Handshake Protocol SERVER-HELLO message"""

    def __init__(self, version=None):
        """
        :param tuple version: protocol version the server is expected to
            select, None to accept any
        """
        c_type = ContentType.handshake
        h_type = SSL2HandshakeType.server_hello
        super(ExpectServerHello2, self).__init__(c_type, h_type)
        self.version = version

    def process(self, state, msg):
        """
        Process the message and update state accordingly

        :type state: `~ConnectionState`
        :param state: overall state of TLS connection

        :type msg: Message
        :param msg: TLS Message read from socket
        """
        # the value is faked for SSLv2 protocol, but let's just check sanity
        assert msg.contentType == ContentType.handshake

        parser = Parser(msg.write())
        hs_type = parser.get(1)
        assert hs_type == SSL2HandshakeType.server_hello

        server_hello = ServerHello2().parse(parser)

        state.handshake_messages.append(server_hello)
        state.handshake_hashes.update(msg.write())

        self._cmp_eq(self.version, server_hello.server_version,
                     f_str="Server picked unexpected protocol version. "
                           "Expected: {0}, received: {1}.")

        if server_hello.session_id_hit:
            state.resuming = True
        # NOTE(review): session_id is used both as session ID and as the
        # server random; presumably it holds the SSLv2 CONNECTION-ID —
        # confirm against tlslite's ServerHello2
        state.session_id = server_hello.session_id
        state.server_random = server_hello.session_id
        state.version = server_hello.server_version
        state.msg_sock.version = server_hello.server_version

        # fake a certificate message so finding the server public key works
        x509 = X509()
        x509.parseBinary(server_hello.certificate)
        cert_chain = X509CertChain([x509])
        certificate = Certificate(CertificateType.x509)
        certificate.create(cert_chain)
        state.handshake_messages.append(certificate)
        # fake message so don't update handshake hashes
class ExpectCertificate(ExpectHandshake):
    """Processing TLS Handshake protocol Certificate messages"""

    def __init__(self, cert_type=CertificateType.x509):
        """
        :param int cert_type: type of certificates expected in the message
        """
        super(ExpectCertificate, self).__init__(ContentType.handshake,
                                                HandshakeType.certificate)
        self.cert_type = cert_type
        # the last parsed message is kept so that an identical message can
        # reuse the already parsed object instead of being parsed again
        self._old_cert = None
        self._old_cert_bytes = None
        self._old_cert_version = None

    def process(self, state, msg):
        """
        Parse the Certificate message and add it to the handshake transcript

        :type state: `~ConnectionState`
        :param state: overall state of TLS connection
        :type msg: Message
        :param msg: TLS Message read from socket
        """
        assert msg.contentType == ContentType.handshake
        msg_bytes = msg.write()

        # Certificate parsing depends on the protocol version (TLS 1.3
        # changed the message format), so the cache is only valid for the
        # same bytes *and* the same version
        if self._old_cert_bytes is not None and \
                msg_bytes == self._old_cert_bytes and \
                self._old_cert_version == state.version:
            cert = self._old_cert
        else:
            parser = Parser(msg_bytes)
            hs_type = parser.get(1)
            assert hs_type == HandshakeType.certificate

            cert = Certificate(self.cert_type, state.version)
            cert.parse(parser)
            self._old_cert_bytes = msg_bytes
            self._old_cert_version = state.version
            self._old_cert = cert

        state.handshake_messages.append(cert)
        state.handshake_hashes.update(msg_bytes)
class ExpectCertificateVerify(ExpectHandshake):
    """
    Processing TLS Handshake protocol Certificate Verify messages.

    :param tuple(int,int) version: Expected TLS version of the message.
        If not provided will be taken from the state.
    :param tuple(int,int) sig_alg: Expected value of the signature scheme
        created by the server. If not provided it will be compared with
        signature algorithm extension from client hello.
    :param hash_file: binary file-like object (anything with a ``write()``
        method accepting bytes) where hashes of the signature context will
        be logged; not supported with EdDSA signatures
    :param sig_file: binary file-like object (anything with a ``write()``
        method accepting bytes) where the signatures themselves will be
        logged
    """

    def __init__(
        self, version=None, sig_alg=None, hash_file=None, sig_file=None
    ):
        super(ExpectCertificateVerify, self).__init__(
            ContentType.handshake,
            HandshakeType.certificate_verify)
        self.version = version
        self.sig_alg = sig_alg
        self.hash_file = hash_file
        self.sig_file = sig_file

    def process(self, state, msg):
        """
        Parse the CertificateVerify message and verify its signature.

        :type state: `~ConnectionState`
        :type msg: Message
        :raises AssertionError: when the signature scheme is not acceptable
            or the signature does not verify
        :raises ValueError: when ``hash_file`` is set but the server used
            an EdDSA signature (it has no separate hash that could be logged)
        """
        assert msg.contentType == ContentType.handshake
        msg_bytes = msg.write()
        parser = Parser(msg_bytes)
        hs_type = parser.get(1)
        assert hs_type == HandshakeType.certificate_verify

        if self.version is None:
            self.version = state.version

        cert_v = CertificateVerify(self.version)
        cert_v.parse(parser)

        if self.sig_alg:
            assert self.sig_alg == cert_v.signatureAlgorithm
        else:
            c_hello = state.get_last_message_of_type(ClientHello)
            ext = c_hello.getExtension(ExtensionType.signature_algorithms)
            assert cert_v.signatureAlgorithm in ext.sigalgs
            key_type = state.get_server_public_key().key_type
            if key_type == "rsa-pss":
                # in TLS 1.3 only RSA-PSS signatures are allowed
                assert cert_v.signatureAlgorithm in (
                    SignatureScheme.rsa_pss_pss_sha256,
                    SignatureScheme.rsa_pss_pss_sha384,
                    SignatureScheme.rsa_pss_pss_sha512)
            elif key_type == "rsa":
                # in TLS 1.3 only RSA-PSS signatures are allowed
                assert cert_v.signatureAlgorithm in (
                    SignatureScheme.rsa_pss_rsae_sha256,
                    SignatureScheme.rsa_pss_rsae_sha384,
                    SignatureScheme.rsa_pss_rsae_sha512)
            elif key_type in ("Ed25519", "Ed448"):
                assert cert_v.signatureAlgorithm in (
                    SignatureScheme.ed25519,
                    SignatureScheme.ed448)
                if getattr(SignatureScheme, key_type.lower()) != \
                        cert_v.signatureAlgorithm:
                    raise AssertionError(
                        "Mismatched signature ({0}) for used key ({1})"
                        .format(
                            SignatureScheme.toStr(cert_v.signatureAlgorithm),
                            key_type))
            else:
                assert key_type == "ecdsa"
                curve_name = state.get_server_public_key().curve_name
                assert curve_name in ("NIST256p", "NIST384p", "NIST521p")
                sigalg = cert_v.signatureAlgorithm
                assert sigalg in ECDSA_SIG_TLS1_3_ALL
                hash_name = curve_name_to_hash_tls13(curve_name)
                # in TLS 1.3 the hash is bound to key curve
                if sigalg != (getattr(HashAlgorithm, hash_name),
                              SignatureAlgorithm.ecdsa):
                    raise AssertionError(
                        "Invalid signature type for {1} key, "
                        "received: {0}"
                        .format(SignatureScheme.toStr(sigalg), curve_name))

        # select the parameters for signature verification
        salg = cert_v.signatureAlgorithm
        if salg in (SignatureScheme.ed25519, SignatureScheme.ed448):
            hash_name = "intrinsic"
            padding = None
            salt_len = None
        elif salg[1] == SignatureAlgorithm.ecdsa:
            hash_name = HashAlgorithm.toStr(salg[0])
            padding = None
            salt_len = None
        else:
            scheme = SignatureScheme.toRepr(salg)
            hash_name = SignatureScheme.getHash(scheme)
            padding = SignatureScheme.getPadding(scheme)
            salt_len = getattr(hashlib, hash_name)().digest_size

        # the transcript must be hashed before this message is added to it
        transcript_hash = state.handshake_hashes.digest(state.prf_name)
        sig_context = bytearray(b'\x20' * 64 +
                                b'TLS 1.3, server CertificateVerify' +
                                b'\x00') + transcript_hash
        if not state.get_server_public_key().hashAndVerify(
                cert_v.signature, sig_context, padding, hash_name,
                salt_len):
            raise AssertionError("Signature verification failed")

        if self.hash_file:
            if hash_name == "intrinsic":
                # "intrinsic" is not a hashlib function name, so there is
                # no digest of the context that could be logged
                raise ValueError("Logging of signature context hashes is "
                                 "not supported for EdDSA signatures")
            data = getattr(hashlib, hash_name)(sig_context).digest()
            self.hash_file.write(data)
        if self.sig_file:
            self.sig_file.write(cert_v.signature)

        state.handshake_messages.append(cert_v)
        state.handshake_hashes.update(msg_bytes)
class ExpectServerKeyExchange(ExpectHandshake):
    """Processing TLS Handshake protocol Server Key Exchange message"""

    def __init__(self, version=None, cipher_suite=None, valid_sig_algs=None,
                 valid_groups=None, valid_params=None):
        """
        Expect ServerKeyExchange message from server.

        :param tuple(int,int) version: protocol version used to parse the
            message; taken from the connection state when ``None``
        :param int cipher_suite: cipher suite used to parse the message;
            taken from the connection state when ``None``
        :param list(tuple(int,int)) valid_sig_algs: signature algorithms
            the server may use for signing the parameters; when ``None``
            the signature_algorithms extension of the ClientHello is used,
            or SHA-1 based defaults if the extension was not sent
        :param list(int) valid_groups: TLS group identifiers for groups
            that server can use. In case the groups include identifiers
            between 256 and 512 (see RFC 7919), the node will also check
            that the server selected FFDH parameters match the parameters
            specified in the RFC.
        :param set(tuple(int,int)) valid_params: set of explicit expected
            parameters used by the server, the first element of the tuple
            is the expected generator and the second is the prime used for
            the DH calculation. Applicable only to ciphersuites that use
            FFDHE key exchange.
        :raises ValueError: when both ``valid_groups`` and ``valid_params``
            are set
        """
        msg_type = HandshakeType.server_key_exchange
        super(ExpectServerKeyExchange, self).__init__(ContentType.handshake,
                                                      msg_type)
        self.version = version
        self.cipher_suite = cipher_suite
        self.valid_sig_algs = valid_sig_algs
        self.valid_groups = valid_groups
        self.valid_params = valid_params

        if self.valid_groups and self.valid_params:
            raise ValueError("valid_groups and valid_params are exclusive")

    def _checkParams(self, server_key_exchange):
        """
        Check that the server selected acceptable FFDH parameters.

        The accepted set comes from ``valid_params`` or from the RFC 7919
        groups (ids 256-511) listed in ``valid_groups``; when neither
        provides any parameters, nothing is checked.

        :raises AssertionError: when the (g, p) pair is not in the set
        """
        groups = []
        if self.valid_groups and any(i in range(256, 512)
                                     for i in self.valid_groups):
            groups = [RFC7919_GROUPS[i - 256] for i in self.valid_groups
                      if i in range(256, 512)]
        if self.valid_params:
            groups = self.valid_params

        server_params = (server_key_exchange.dh_g, server_key_exchange.dh_p)

        if groups and server_params not in groups:
            # report well-known parameters by name for easier debugging
            for name, params in FFDHE_PARAMETERS.items():
                if server_params == params:
                    raise AssertionError(
                        "DH parameters not from valid set, "
                        "received: {0}".format(name))
            raise AssertionError(
                "DH parameters not from valid set, "
                "received: g:{0}, p:{1}".format(
                    hex(server_params[0]), hex(server_params[1])))

    def process(self, state, msg):
        """
        Process the Server Key Exchange message

        Verifies the signature over the key exchange parameters, checks the
        parameters and stores the resulting premaster secret in
        ``state.key['premaster_secret']``.

        :type state: `~ConnectionState`
        :type msg: Message
        :raises AssertionError: for unacceptable parameters or an
            unsupported cipher suite
        """
        assert msg.contentType == ContentType.handshake
        msg_bytes = msg.write()
        parser = Parser(msg_bytes)
        hs_type = parser.get(1)
        assert hs_type == HandshakeType.server_key_exchange

        if self.version is None:
            self.version = state.version
        if self.cipher_suite is None:
            self.cipher_suite = state.cipher
        valid_sig_algs = self.valid_sig_algs
        valid_groups = self.valid_groups

        server_key_exchange = ServerKeyExchange(self.cipher_suite,
                                                self.version)
        server_key_exchange.parse(parser)

        client_random = state.client_random
        server_random = state.server_random
        public_key = state.get_server_public_key()
        server_hello = state.get_last_message_of_type(ServerHello)
        if server_hello is None:
            # use a fresh instance: setting the version on the ServerHello
            # class itself would leak it into every other ServerHello
            server_hello = ServerHello()
            server_hello.server_version = state.version
        if valid_sig_algs is None:
            # if the value was unset in script, get the advertised value from
            # Client Hello
            client_hello = state.get_last_message_of_type(ClientHello)
            if client_hello is not None:
                sig_algs_ext = client_hello.getExtension(ExtensionType.
                                                         signature_algorithms)
                if sig_algs_ext is not None:
                    valid_sig_algs = sig_algs_ext.sigalgs
            if valid_sig_algs is None:
                # no advertised means support for sha1 only
                valid_sig_algs = [(HashAlgorithm.sha1, SignatureAlgorithm.rsa)]
                if self.cipher_suite in CipherSuite.ecdheEcdsaSuites:
                    valid_sig_algs = [(HashAlgorithm.sha1,
                                       SignatureAlgorithm.ecdsa)]

        try:
            KeyExchange.verifyServerKeyExchange(server_key_exchange,
                                                public_key,
                                                client_random,
                                                server_random,
                                                valid_sig_algs)
        except TLSDecryptionFailed:
            # very rarely validation of signature fails, print it so that
            # we have a chance in debugging it
            print("Bad signature: {0}"
                  .format(b2a_hex(server_key_exchange.signature)),
                  file=sys.stderr)
            raise

        if self.cipher_suite in CipherSuite.dhAllSuites:
            self._checkParams(server_key_exchange)
            state.key_exchange = DHE_RSAKeyExchange(self.cipher_suite,
                                                    clientHello=None,
                                                    serverHello=server_hello,
                                                    privateKey=None)
            state.key['ServerKeyExchange.key_share'] = \
                server_key_exchange.dh_Ys
            state.key['ServerKeyExchange.dh_p'] = server_key_exchange.dh_p
        elif self.cipher_suite in CipherSuite.ecdhAllSuites:
            # extract valid groups from Client Hello
            if valid_groups is None:
                client_hello = state.get_last_message_of_type(ClientHello)
                if client_hello is not None:
                    groups_ext = client_hello.getExtension(ExtensionType.
                                                           supported_groups)
                    if groups_ext is not None:
                        valid_groups = groups_ext.groups
                if valid_groups is None:
                    # no advertised means support for all
                    valid_groups = GroupName.allEC
            state.key_exchange = \
                ECDHE_RSAKeyExchange(self.cipher_suite,
                                     clientHello=None,
                                     serverHello=server_hello,
                                     privateKey=None,
                                     acceptedCurves=valid_groups)
            state.key['ServerKeyExchange.key_share'] = \
                server_key_exchange.ecdh_Ys
        else:
            raise AssertionError("Unsupported cipher selected")
        state.key['premaster_secret'] = state.key_exchange.\
            processServerKeyExchange(public_key,
                                     server_key_exchange)

        state.handshake_messages.append(server_key_exchange)
        state.handshake_hashes.update(msg_bytes)
# RFC 8446 Section 4.2: an implementation MUST reject extensions it
# recognises but which are not allowed in the CertificateRequest message.
# The set below lists every extension defined in RFC 8446 that is not
# allowed there; numeric IDs are annotated with the extension name.
TLS_1_3_CR_FORBIDDEN = set([
    ExtensionType.server_name,
    1,  # max_fragment_length
    ExtensionType.supported_groups,
    14,  # use_srtp
    ExtensionType.heartbeat,
    ExtensionType.alpn,
    19,  # client_certificate_type
    20,  # server_certificate_type
    21,  # padding
    ExtensionType.key_share,
    ExtensionType.pre_shared_key,
    ExtensionType.psk_key_exchange_modes,
    ExtensionType.early_data,
    ExtensionType.cookie,
    ExtensionType.supported_versions,
    49,  # post_handshake_auth
])
class ExpectCertificateRequest(_ExpectExtensionsMessage):
    """Processing TLS Handshake protocol Certificate Request message."""

    def __init__(self, sig_algs=None, cert_types=None,
                 sanity_check_cert_types=True, extensions=None,
                 context=None):
        """
        Set expected parameters for the CertificateRequest message.

        :param sig_algs: a list of signature algorithms that we are
            expecting from server. Needs to be in-order and complete.
            ``None`` to accept any list from server.
            Applicable to TLS 1.2 and later only.
            Do not use together with non-default ``extensions``.
        :param cert_types: a list of client certificate types that we are
            expecting from server. Needs to be in-order and complete.
            ``None`` to accept any list from server.
            Applicable to TLS 1.2 and earlier only.
        :param sanity_check_cert_types: set to ``False`` to disable
            verification checking if every signature algorithm has
            a corresponding client certificate type.
        :param extensions: dictionary with extensions that need to be
            included in the message. Set to ``None`` to accept any, set
            to empty dict to expect no extensions.
            Usable in TLS 1.3 only.
        :param list context: when not ``None``, every received
            CertificateRequest message is appended to this list
        :raises ValueError: when both ``sig_algs`` and ``extensions`` are set
        """
        super(ExpectCertificateRequest, self).__init__(
            ContentType.handshake, HandshakeType.certificate_request,
            extensions)
        self.sanity_check_cert_types = sanity_check_cert_types
        self.context = context
        self.cert_types = cert_types
        self.sig_algs = sig_algs

        if extensions is not None and sig_algs is not None:
            raise ValueError("Can't set sig_algs and extensions at the same "
                             "time")

    @staticmethod
    def _sanity_check_cert_types(cert_request):
        """
        Verify that the CertificateRequest is self-consistent.

        Each advertised signature algorithm must be matched by an advertised
        client certificate type for its key type.

        :raises AssertionError: on an unsupported signature algorithm or a
            missing certificate type
        """
        advertised = cert_request.certificate_types
        for sig_alg in cert_request.supported_signature_algs:
            sig_type = sig_alg[1]
            if sig_type in (SignatureAlgorithm.ecdsa,
                            SignatureAlgorithm.ed25519,
                            SignatureAlgorithm.ed448):
                key_type, cert_type = "ECDSA", "ecdsa_sign"
            elif sig_type == SignatureAlgorithm.rsa:
                key_type, cert_type = "RSA", "rsa_sign"
            elif sig_type == SignatureAlgorithm.dsa:
                key_type, cert_type = "DSA", "dss_sign"
            else:
                # anything else is resolved by its full scheme name,
                # only RSA based schemes are accepted here
                key_type = SignatureScheme.getKeyType(
                    SignatureScheme.toRepr(sig_alg))
                assert key_type == "rsa", \
                    "Unsupported signature algorithm: {0}".format(sig_alg)
                cert_type = "rsa_sign"

            if getattr(ClientCertificateType, cert_type) not in advertised:
                raise AssertionError(
                    "CertificateRequest includes {1} signature algorithms "
                    "({0}) but does not include {2} client "
                    "certificate type".format(sig_alg, key_type, cert_type))

    @staticmethod
    def _get_autohandler(ext_id):
        """
        Return the default handler for the extension type.

        Unknown (future or GREASE) extension types get ``None``.
        """
        return _CR_EXT_HANDLER.get(ext_id)

    def _process_extensions(self, state, msg):
        """
        Run the configured or automatic handler for every extension.

        :raises AssertionError: for forbidden or mismatched extensions
        :raises ValueError: for a handler of an unsupported type
        """
        for ext in msg.extensions:
            ext_id = ext.extType
            if ext_id in TLS_1_3_CR_FORBIDDEN:
                raise AssertionError(
                    "Server sent extension that is explicitly forbidden in "
                    "CertificateRequest messages: {0}".format(
                        ExtensionType.toStr(ext_id)))

            handler = self.extensions[ext_id] if self.extensions else None
            if handler is None:
                handler = self._get_autohandler(ext_id)

            if handler is None:
                # since server can send arbitrary extensions, we need to
                # be able to process them, so if there is no handler for
                # them we can just do nothing
                continue
            if callable(handler):
                handler(state, ext)
            elif isinstance(handler, TLSExtension):
                if not handler == ext:
                    raise AssertionError(
                        "Expected extension not matched for type {0}, "
                        "received: {1}".format(ExtensionType.toStr(ext_id),
                                               ext))
            else:
                raise ValueError("Bad extension handler for id {0}".format(
                    ExtensionType.toStr(ext_id)))

    def process(self, state, msg):
        """
        Check received Certificate Request

        :type state: ConnectionState
        :type msg: Message
        """
        assert msg.contentType == ContentType.handshake
        msg_bytes = msg.write()
        parser = Parser(msg_bytes)
        hs_type = parser.get(1)
        assert hs_type == HandshakeType.certificate_request

        cert_request = CertificateRequest(state.version)
        cert_request.parse(parser)

        self._cmp_eq_list(self.sig_algs,
                          cert_request.supported_signature_algs,
                          SignatureScheme,
                          f_str="Unexpected signature algorithms. Got: {1}, "
                                "expected: {0}")
        self._cmp_eq_list(self.cert_types, cert_request.certificate_types,
                          ClientCertificateType,
                          f_str="Unexpected client certificate types. Got: "
                                "{1}, expected: {0}")

        # only in TLS 1.2 do the sig algs coexist with cert types
        if self.sanity_check_cert_types and state.version == (3, 3):
            self._sanity_check_cert_types(cert_request)

        # the message carries extensions only in TLS 1.3 and later
        if state.version >= (3, 4):
            self._compare_extensions(cert_request)
            self._process_extensions(state, cert_request)

        if self.context is not None:
            self.context.append(cert_request)
        state.handshake_messages.append(cert_request)
        state.handshake_hashes.update(msg_bytes)
class ExpectServerHelloDone(ExpectHandshake):
    """Processing TLS Handshake protocol ServerHelloDone messages"""

    def __init__(self):
        """Set up the expectation of a ServerHelloDone message."""
        super(ExpectServerHelloDone, self).__init__(
            ContentType.handshake, HandshakeType.server_hello_done)

    def process(self, state, msg):
        """
        Parse the ServerHelloDone message and add it to the transcript.

        :type state: ConnectionState
        :type msg: Message
        """
        assert msg.contentType == ContentType.handshake
        raw = msg.write()
        parser = Parser(raw)
        msg_type = parser.get(1)
        assert msg_type == HandshakeType.server_hello_done

        done_msg = ServerHelloDone()
        done_msg.parse(parser)

        state.handshake_messages.append(done_msg)
        state.handshake_hashes.update(raw)
class ExpectChangeCipherSpec(Expect):
    """
    Processing TLS Change Cipher Spec messages.

    .. note::
       In SSLv3 up to TLS 1.2, the message modifies the state of record
       layer to expect encrypted records *after* receiving this message.
       In case of renegotiation, record layer will expect records
       encrypted with the newly negotiated keys.

       In TLS 1.3 it has no effect on record layer encryption.
    """

    def __init__(self):
        """Set up the expectation of a ChangeCipherSpec message."""
        super(ExpectChangeCipherSpec, self).__init__(
            ContentType.change_cipher_spec)

    def process(self, state, msg):
        """
        Check the message and, before TLS 1.3, switch the read state.

        :type state: ConnectionState
        :type msg: Message
        """
        assert msg.contentType == ContentType.change_cipher_spec
        ccs = ChangeCipherSpec().parse(Parser(msg.write()))
        assert ccs.type == 1

        if state.version >= (3, 4):
            # in TLS 1.3 the CCS does not have any effect on encryption
            return

        if state.resuming:
            # on resumption the pending states are computed here
            state.msg_sock.encryptThenMAC = state.encrypt_then_mac
            calc_pending_states(state)

        state.msg_sock.changeReadState()
        our_limit = state._our_record_size_limit
        if our_limit:
            state.msg_sock.recv_record_limit = our_limit
class ExpectVerify(ExpectHandshake):
    """Processing of SSLv2 SERVER-VERIFY message"""

    def __init__(self):
        """Set up the expectation of an SSLv2 SERVER-VERIFY message."""
        super(ExpectVerify, self).__init__(
            ContentType.handshake, SSL2HandshakeType.server_verify)

    def process(self, state, msg):
        """
        Check if the VERIFY message has expected value

        Only the content type and handshake type are checked, the message
        is not added to the connection state.
        """
        assert msg.contentType == ContentType.handshake
        parser = Parser(msg.write())
        received_type = parser.get(1)
        assert received_type == SSL2HandshakeType.server_verify
class ExpectFinished(ExpectHandshake):
    """
    Processing TLS handshake protocol Finished message.

    .. note::
       In TLS 1.3 the message will modify record layer to start *sending*
       records with encryption using the ``client_handshake_traffic_secret``
       keys.
       It will also modify the record layer to start expecting the records
       to be encrypted with ``server_application_traffic_secret`` keys.
    """

    def __init__(self, version=None, description=None):
        """
        Initialize object.

        .. note::
           The ``description`` parameter MUST be specified as a keyword
           argument, i.e. read the definition as
           ``(self, *, description=None)`` (see PEP 3102).
           Otherwise the behaviour of this node is not guaranteed if new
           arguments are added to it (as they will be added *before* the
           ``description`` argument).

        :param tuple version: protocol version of the message, SSLv2
            (``(0, 2)`` or ``(2, 0)``) selects the SERVER-FINISHED message;
            the negotiated version is used when ``None``
        :param str description: name or comment attached to the node, it
            will be printed when :py:func:`str` or :py:func:`repr` is called
            on the node.
        """
        if version in ((0, 2), (2, 0)):
            expected_type = SSL2HandshakeType.server_finished
        else:
            expected_type = HandshakeType.finished
        super(ExpectFinished, self).__init__(ContentType.handshake,
                                             expected_type)
        self.version = version
        self.description = description

    def process(self, state, msg):
        """
        Verify the server Finished message and update connection state.

        :type state: ConnectionState
        :type msg: Message
        """
        assert msg.contentType == ContentType.handshake
        raw = msg.write()
        parser = Parser(raw)
        hs_type = parser.get(1)
        assert hs_type == self.handshake_type

        if self.version is None:
            self.version = state.version
        is_ssl2 = self.version in ((0, 2), (2, 0))

        if is_ssl2:
            finished = ServerFinished()
        else:
            finished = Finished(self.version, state.prf_size)
        finished.parse(parser)

        if is_ssl2:
            state.session_id = finished.verify_data
        elif self.version <= (3, 3):
            label = b'server finished' if state.client \
                else b'client finished'
            verify_expected = calc_key(state.version,
                                       state.key['master_secret'],
                                       state.cipher,
                                       label,
                                       state.handshake_hashes,
                                       output_length=12)
            assert finished.verify_data == verify_expected
        else:
            # TLS 1.3: HMAC of the transcript with the finished key
            finished_key = HKDF_expand_label(
                state.key['server handshake traffic secret'],
                b'finished',
                b'',
                state.prf_size,
                state.prf_name)
            transcript_hash = state.handshake_hashes.digest(state.prf_name)
            verify_expected = secureHMAC(finished_key, transcript_hash,
                                         state.prf_name)
            assert finished.verify_data == verify_expected

        state.handshake_messages.append(finished)
        state.key['server_verify_data'] = finished.verify_data
        state.handshake_hashes.update(raw)

        if is_ssl2:
            state.msg_sock.handshake_finished = True

        if self.version <= (3, 3):
            return

        # in TLS 1.3 ChangeCipherSpec is a no-op, so the record layer
        # changes are attached to this message, which is always sent
        state.msg_sock.changeWriteState()

        # the application traffic keys are needed already now to correctly
        # interpret alerts regarding Certificate, CertificateVerify and
        # Finished
        prf_name = state.prf_name
        transcript = state.handshake_hashes

        # master secret: extract with all-zero input keyed by the secret
        # derived from the handshake secret
        derived = derive_secret(state.key['handshake secret'], b'derived',
                                None, prf_name)
        master = secureHMAC(derived, bytearray(state.prf_size), prf_name)
        state.key['master secret'] = master

        client_app = derive_secret(master, b'c ap traffic', transcript,
                                   prf_name)
        state.key['client application traffic secret'] = client_app
        server_app = derive_secret(master, b's ap traffic', transcript,
                                   prf_name)
        state.key['server application traffic secret'] = server_app

        state.key['exporter master secret'] = derive_secret(
            master, b'exp master', transcript, prf_name)

        # set up the encryption keys for application data
        state.msg_sock.calcTLS1_3PendingState(
            state.cipher, client_app, server_app, None)
        state.msg_sock.changeReadState()

    def __repr__(self):
        """Return human readable representation of the object."""
        return self._repr(['description'])
class ExpectEncryptedExtensions(_ExpectExtensionsMessage):
    """Processing of the TLS handshake protocol Encrypted Extensions message"""

    def __init__(self, extensions=None):
        """
        Set up the expectation of an EncryptedExtensions message.

        :param extensions: expected extensions, passed to the base class;
            when ``None`` any extension also offered by the client is
            accepted
        """
        super(ExpectEncryptedExtensions, self).__init__(
            ContentType.handshake, HandshakeType.encrypted_extensions,
            extensions)

    def _compare_extensions_in_ee(self, srv_exts, cln_hello):
        """
        Verify that server provided extensions match exactly expected list.

        When no extensions were configured, every extension sent by the
        server must have been present in the client's ClientHello.

        :raises AssertionError: for unsolicited extensions
        """
        # check if received extensions match the set extensions
        self._compare_extensions(srv_exts)

        if self.extensions is not None or not srv_exts.extensions:
            return

        offered = set(ext.extType for ext in cln_hello.extensions)
        received = set(ext.extType for ext in srv_exts.extensions)
        unsolicited = received.difference(offered)
        if unsolicited:
            raise AssertionError(
                "Server sent unexpected extension(s): {0}".format(
                    ", ".join(ExtensionType.toStr(i) for i in unsolicited)))

    @staticmethod
    def _get_autohandler(ext_id):
        """
        Return the default handler for the extension type.

        :raises ValueError: when no default handler exists
        """
        try:
            return _EE_EXT_HANDLER[ext_id]
        except KeyError:
            raise ValueError("No autohandler for {0}".format(
                ExtensionType.toStr(ext_id)))

    def _process_extensions(self, state, srv_exts):
        """
        Check if extensions are correct.

        :raises AssertionError: for extensions not allowed in the message
            or not matching the expected value
        :raises ValueError: for a handler of an unsupported type
        """
        # fix these constants, when the extensions are implemented
        allowed = [ExtensionType.server_name,
                   1,  # max_fragment_length - RFC 6066
                   ExtensionType.supported_groups,
                   14,  # use_srtp - RFC 5764
                   ExtensionType.heartbeat,  # RFC 6520
                   ExtensionType.alpn,
                   19,  # client_certificate_type
                   # draft-ietf-tls-tls13-28 / RFC 7250
                   20,  # server_certificate_type
                   # draft-ietf-tls-tls13-28 / RFC 7250
                   ExtensionType.record_size_limit,  # RFC 8449
                   ExtensionType.early_data]

        for ext in srv_exts.extensions:
            ext_id = ext.extType
            if ext_id not in allowed:
                raise AssertionError(
                    "Server sent unsupported extension of type {0}".format(
                        ExtensionType.toStr(ext_id)))

            handler = self.extensions[ext_id] if self.extensions else None
            # use automatic handlers for some extensions
            if handler is None:
                handler = self._get_autohandler(ext_id)

            if callable(handler):
                handler(state, ext)
            elif isinstance(handler, TLSExtension):
                if not handler == ext:
                    raise AssertionError(
                        "Expected extension not matched for type {0}, "
                        "received: {1}".format(ExtensionType.toStr(ext_id),
                                               ext))
            else:
                raise ValueError("Bad extension handler for id {0}".format(
                    ExtensionType.toStr(ext_id)))

    def process(self, state, msg):
        """
        Parse and check the EncryptedExtensions message.

        :type state: ConnectionState
        :type msg: Message
        """
        assert msg.contentType == ContentType.handshake
        raw = msg.write()
        parser = Parser(raw)
        hs_type = parser.get(1)
        assert hs_type == self.handshake_type

        srv_exts = EncryptedExtensions().parse(parser)

        # the extensions are compared against the ones the client offered
        client_hello = state.get_last_message_of_type(ClientHello)
        self._compare_extensions_in_ee(srv_exts, client_hello)

        if srv_exts.extensions:
            self._process_extensions(state, srv_exts)

        state.handshake_messages.append(srv_exts)
        state.handshake_hashes.update(raw)
class ExpectNewSessionTicket(ExpectHandshake):
    """Processing TLS handshake protocol new session ticket message."""

    def __init__(self, version=None, description=None):
        """
        Initialise object.

        .. note::
           The ``description`` parameter MUST be specified as a keyword
           argument, i.e. read the definition as
           ``(self, *, description=None)`` (see PEP 3102).
           Otherwise the behaviour of this node is not guaranteed if new
           arguments are added to it (as they will be added *before* the
           ``description`` argument).

        :param tuple version: parse the message as in the specified TLS
            version, use negotiated version by default
        :param str description: name or comment attached to the node, it
            will be printed when :py:func:`str` or :py:func:`repr` is called
            on the node.
        """
        super(ExpectNewSessionTicket, self).__init__(
            ContentType.handshake,
            HandshakeType.new_session_ticket)
        self.description = description
        self.version = version

    def process(self, state, msg):
        """
        Parse, verify and process the message.

        The ticket is stamped with the reception time and appended to
        ``state.session_tickets``.
        """
        assert msg.contentType == ContentType.handshake
        raw = msg.write()
        parser = Parser(raw)
        hs_type = parser.get(1)
        assert hs_type == HandshakeType.new_session_ticket

        if self.version is None:
            self.version = state.version
        pre_tls13 = self.version < (3, 4)

        ticket_class = NewSessionTicket1_0 if pre_tls13 else NewSessionTicket
        ticket = ticket_class().parse(parser)
        ticket.time = time.time()
        state.session_tickets.append(ticket)

        if pre_tls13:
            # in TLS 1.2 and earlier tickets are part of the Handshake, so
            # they need to be hashed
            state.handshake_messages.append(ticket)
            state.handshake_hashes.update(raw)

    def __repr__(self):
        """Return human readable representation of object."""
        return self._repr(['description'])
class ExpectHelloRequest(ExpectHandshake):
    """Processing of TLS handshake protocol hello request message."""

    def __init__(self, description=None):
        """
        Initialise object.

        .. note::
           The ``description`` parameter MUST be specified as a keyword
           argument, i.e. read the definition as
           ``(self, *, description=None)`` (see PEP 3102).
           Otherwise the behaviour of this node is not guaranteed if new
           arguments are added to it (as they will be added *before* the
           ``description`` argument).

        :param str description: name or comment attached to the node, it
            will be printed when :py:func:`str` or :py:func:`repr` is called
            on the node.
        """
        super(ExpectHelloRequest, self).__init__(ContentType.handshake,
                                                 HandshakeType.hello_request)
        self.description = description

    def process(self, state, msg):
        """
        Parse, verify and process the message.

        Only checks that the message is a well-formed HelloRequest; it is
        not added to the connection state.
        """
        assert msg.contentType == ContentType.handshake
        parser = Parser(msg.write())
        received_type = parser.get(1)
        assert received_type == HandshakeType.hello_request

        # parsing raises for malformed messages
        HelloRequest().parse(parser)

    def __repr__(self):
        """Return human readable representation of object."""
        return self._repr(['description'])
class ExpectAlert(Expect):
    """Processing TLS Alert message"""

    def __init__(self, level=None, description=None):
        """
        Set up the expectation of an Alert message.

        :param int level: expected alert level, ``None`` to accept any
        :param description: expected alert description, or an iterable of
            acceptable descriptions; ``None`` to accept any
        """
        super(ExpectAlert, self).__init__(ContentType.alert)
        self.level = level
        self.description = description

    def process(self, state, msg):
        """
        Check the received alert against the expected level and description.

        A single (non-iterable) expected description is converted in place
        into a one-element tuple.

        :raises AssertionError: describing every mismatch found
        """
        assert msg.contentType == ContentType.alert
        alert = Alert()
        alert.parse(Parser(msg.write()))

        problems = []
        if self.level is not None and alert.level != self.level:
            problems.append("Alert level {0} != {1}".format(alert.level,
                                                            self.level))

        if self.description is not None:
            # allow for multiple choice for description
            if not isinstance(self.description, Iterable):
                self.description = tuple([self.description])

            if alert.description not in self.description:
                quoted = ["\"{0}\"".format(AlertDescription.toStr(i))
                          for i in self.description]
                # "a, b, c or d" style list of acceptable descriptions
                expected = ", ".join(quoted[:-2] +
                                     [" or ".join(quoted[-2:])])
                received = AlertDescription.toStr(alert.description)
                problems.append("Expected alert description {0} does not "
                                "match received \"{1}\""
                                .format(expected, received))

        if problems:
            raise AssertionError(", ".join(problems))

    def __repr__(self):
        """Return human readable representation of object."""
        return self._repr(["level", "description"])
class ExpectSSL2Alert(ExpectHandshake):
    """Wait for an SSLv2 handshake protocol error message."""

    def __init__(self, error=None):
        """
        Set up the node.

        :param int error: expected SSLv2 error code, ``None`` to accept any
            error code
        """
        super(ExpectSSL2Alert, self).__init__(
            ContentType.handshake, SSL2HandshakeType.error)
        self.error = error

    def process(self, state, msg):
        """Check the message type and, if configured, the error code."""
        assert msg.contentType == ContentType.handshake

        reader = Parser(msg.write())
        msg_type = reader.get(1)
        assert msg_type == SSL2HandshakeType.error

        if self.error is None:
            return
        assert self.error == reader.get(2)
class ExpectApplicationData(Expect):
    """Processing Application Data message"""

    def __init__(self, data=None, size=None, output=None, description=None):
        """
        Initialise object.

        :type data: bytes-like
        :param data: exact payload expected in the record, ``None`` to accept
            any payload
        :param int size: exact length of the expected payload, ``None`` to
            accept any length
        :param output: object with a ``write()`` method that the received
            payload will be logged to, ``None`` to not log it
        :param str description: name or comment attached to the node, it
            will be printed when :py:func:`str` is called on the node
        """
        super(ExpectApplicationData, self).\
            __init__(ContentType.application_data)
        self.data = data
        self.size = size
        self.output = output
        self.description = description

    def __str__(self):
        """Return human readable representation of the object."""
        return self._repr(['data', 'size', 'description'])

    def process(self, state, msg):
        """
        Check the received application data against the expectations.

        :raises AssertionError: when the payload or its size doesn't match
        """
        assert msg.contentType == ContentType.application_data
        data = msg.write()

        # compare against None explicitly so that an expected empty payload
        # (b"") or an expected size of 0 is actually verified
        if self.data is not None and self.data != data:
            raise AssertionError("ApplicationData with unexpected payload: "
                                 "{0!r}, expected: {1!r}"
                                 .format(data, self.data))
        if self.size is not None and len(data) != self.size:
            raise AssertionError("ApplicationData of unexpected size: {0}, "
                                 "expected: {1}".format(len(data), self.size))
        if self.output:
            self.output.write("ExpectApplicationData received payload:\n")
            # terminate the payload line so that the end marker below
            # starts on its own line
            self.output.write(repr(data) + "\n")
            self.output.write("ExpectApplicationData end of payload.\n")
class ExpectHeartbeat(ExpectMessage):
    """Processing of heartbeat messages."""

    def __init__(self, message_type=HeartbeatMessageType.heartbeat_response,
                 payload=None, padding_size=None):
        """
        Set up waiting for a heartbeat message.

        :type message_type: int
        :param message_type: Type of heartbeat messages to wait for, see
            :py:class:`~tlslite.constants.HeartbeatMessageType` for defined
            types
        :type payload: bytes-like
        :param payload: literal value of the payload to expect, if set to
            ``None``, any payload will be accepted
        :type padding_size: int
        :param padding_size: exact length of padding that will be expected,
            if set to ``None``, any padding of at least 16 bytes (the minimum
            required by RFC 6520) will be accepted
        """
        super(ExpectHeartbeat, self).\
            __init__(ContentType.heartbeat)
        self.message_type = message_type
        self.payload = payload
        self.padding_size = padding_size

    def process(self, state, msg):
        """
        Check if the ``msg`` meets the requirements for the message.

        :raises AssertionError: when the message type, payload or padding
            length doesn't match the expectations
        """
        assert msg.contentType == ContentType.heartbeat
        parser = Parser(msg.write())
        heartbeat = Heartbeat().parse(parser)

        self._cmp_eq(self.message_type, heartbeat.message_type,
                     HeartbeatMessageType,
                     "Unexpected heartbeat message type. Expected: {0}, "
                     "received: {1}.")

        self._cmp_eq(self.payload, heartbeat.payload,
                     f_str="Unexpected payload in Heartbeat message "
                     "received. Expected: {0!r}, received: {1!r}")

        padding_len = len(heartbeat.padding)
        if self.padding_size is None:
            # RFC 6520 requires at least 16 bytes of padding
            if padding_len < 16:
                raise AssertionError(
                    "Server sent too short padding in heartbeat message. "
                    "Expected at least: 16, received: {0}"
                    .format(padding_len))
        elif padding_len != self.padding_size:
            raise AssertionError(
                "Server sent unexpected size of padding "
                "in heartbeat message. Expected: {0}, "
                "received: {1}".format(self.padding_size, padding_len))
class ExpectNoMessage(Expect):
    """
    Placeholder node standing for a timeout while listening for a message.

    :ivar timeout: number of seconds to wait for a message before giving up,
        fractional values are allowed
    :vartype timeout: int or float
    """

    def __init__(self, timeout=0.1):
        super(ExpectNoMessage, self).__init__(None)
        self.timeout = timeout

    def process(self, state, msg):
        """Intentionally a no-op: there is nothing to verify."""
class ExpectClose(Expect):
    """Placeholder node standing for the TCP connection being closed."""

    def __init__(self):
        super(ExpectClose, self).__init__(None)

    def process(self, state, msg):
        """Close the local end of the connection's socket."""
        connection_socket = state.msg_sock.sock
        connection_socket.close()
class ExpectCertificateStatus(ExpectHandshake):
    """Wait for the CertificateStatus handshake message (RFC 6066)."""

    def __init__(self):
        super(ExpectCertificateStatus, self).__init__(
            ContentType.handshake, HandshakeType.certificate_status)

    def process(self, state, msg):
        """
        Parse the message and add it to the handshake transcript.

        The parsed message is appended to ``state.handshake_messages`` and
        its raw bytes are fed into ``state.handshake_hashes``.
        """
        assert msg.contentType == ContentType.handshake

        reader = Parser(msg.write())
        msg_type = reader.get(1)
        assert msg_type == HandshakeType.certificate_status

        status_msg = CertificateStatus().parse(reader)

        state.handshake_messages.append(status_msg)
        state.handshake_hashes.update(msg.write())
class ExpectKeyUpdate(ExpectHandshake):
    """Processing of post-handshake KeyUpdate message from RFC 8446"""

    def __init__(self, message_type=None):
        """
        Initialize object.

        :type message_type: int
        :param message_type: type of KeyUpdate msg, either
            update_not_requested or update_requested; ``None`` to accept
            either type
        """
        super(ExpectKeyUpdate, self).__init__(
            ContentType.handshake, HandshakeType.key_update)
        self.message_type = message_type

    def process(self, state, msg):
        """
        Parse, verify and process the message.

        On success, the server application traffic secret in ``state.key``
        is replaced with the updated one.

        :type state: ConnectionState
        :type msg: Message
        :raises AssertionError: when the message is not a KeyUpdate or its
            type doesn't match the expected one
        """
        assert msg.contentType == self.content_type
        parser = Parser(msg.write())
        hs_type = parser.get(1)
        assert hs_type == self.handshake_type

        keyupdate = KeyUpdate().parse(parser)

        # compare against None explicitly: update_not_requested is 0, and
        # with no expectation set any message type must be accepted
        if self.message_type is not None and \
                keyupdate.message_type != self.message_type:
            raise AssertionError(
                "Unexpected KeyUpdate message type. Expected: {0}, "
                "received: {1}".format(self.message_type,
                                       keyupdate.message_type))

        _, sr_app_secret = state.msg_sock.\
            calcTLS1_3KeyUpdate_sender(
                state.cipher,
                state.key['client application traffic secret'],
                state.key['server application traffic secret'])
        state.key['server application traffic secret'] = sr_app_secret