# Author: Hubert Kario, (c) 2015
# Released under Gnu GPL v2.0, see LICENSE file for details
"""Parsing and processing of received TLS messages"""
from __future__ import print_function
import itertools
from functools import partial
import sys
import time
import tlslite.utils.tlshashlib as hashlib
from tlslite.constants import ContentType, HandshakeType, CertificateType,\
HashAlgorithm, SignatureAlgorithm, ExtensionType,\
SSL2HandshakeType, CipherSuite, GroupName, AlertDescription, \
SignatureScheme, TLS_1_3_HRR, HeartbeatMode, \
TLS_1_1_DOWNGRADE_SENTINEL, TLS_1_2_DOWNGRADE_SENTINEL, \
HeartbeatMessageType, ClientCertificateType, CertificateStatusType
from tlslite.messages import ServerHello, Certificate, ServerHelloDone,\
ChangeCipherSpec, Finished, Alert, CertificateRequest, ServerHello2,\
ServerKeyExchange, ClientHello, ServerFinished, CertificateStatus, \
CertificateVerify, EncryptedExtensions, NewSessionTicket, Heartbeat,\
KeyUpdate, HelloRequest, NewSessionTicket1_0
from tlslite.extensions import TLSExtension, ALPNExtension
from tlslite.utils.codec import Parser, Writer
from tlslite.utils.compat import b2a_hex
from tlslite.utils.cryptomath import secureHMAC, derive_secret, \
HKDF_expand_label
from tlslite.mathtls import RFC7919_GROUPS, FFDHE_PARAMETERS, calc_key
from tlslite.keyexchange import KeyExchange, DHE_RSAKeyExchange, \
ECDHE_RSAKeyExchange
from tlslite.x509 import X509
from tlslite.x509certchain import X509CertChain
from tlslite.errors import TLSDecryptionFailed
from tlslite.handshakehashes import HandshakeHashes
from tlslite.handshakehelpers import HandshakeHelpers
from .handshake_helpers import calc_pending_states, kex_for_group, \
curve_name_to_hash_tls13
from .helpers import ECDSA_SIG_TLS1_3_ALL
from .tree import TreeNode
# pylint: disable=import-error,no-name-in-module
# pylint: disable=bad-option-value,deprecated-class
if sys.version_info >= (3, 3):
from collections.abc import Iterable
else:
from collections import Iterable
# pylint: enable=bad-option-value,deprecated-class
# pylint: enable=import-error,no-name-in-module
[docs]
class Expect(TreeNode):
    """Base class for nodes that read and check messages from the peer."""

    def __init__(self, content_type):
        """
        Prepare the node for use in the tree graph.

        :type content_type: int
        :param content_type: record layer content type of the messages
            this node handles
        """
        super(Expect, self).__init__()
        self.content_type = content_type

    def is_expect(self):
        """Report that this node is a message processor (always True)."""
        return True

    def is_command(self):
        """Report that this node is not a command (always False)."""
        return False

    def is_generator(self):
        """Report that the node does not generate messages (always False)."""
        return False

    def is_match(self, msg):
        """
        Check if this node can handle the message.

        Note that the msg is a raw, unparsed message of indicated type that
        requires calling write() to get a raw bytearray() representation of
        it. Only the record layer content type is compared here.

        :type msg: tlslite.messages.Message
        :param msg: raw message to check
        :rtype: bool
        """
        return msg.contentType == self.content_type

    def process(self, state, msg):
        """
        Process the message and update the state accordingly.

        Must be overridden by subclasses.

        :type state: tlsfuzzer.runner.ConnectionState
        :param state: current connection state, needs to be updated after
            parsing the message by inheriting classes
        :type msg: tlslite.messages.Message
        :param msg: raw message to parse
        :raises NotImplementedError: always, in this base class
        """
        raise NotImplementedError("Subclasses need to implement this!")
[docs]
class ExpectMessage(Expect):
    """Common methods for handling TLS messages."""

    @staticmethod
    def _cmp_eq(our, recv, field_type=None, f_str=None):
        """
        Check if expected value matched received, if defined.

        If our is not None, compare with recv. If they don't match, try
        translating them with field_type.toStr() method and raise
        AssertionError with message formatted with f_str. First parameter
        to .format() will be expected value and the second one will be the
        received one.

        :raises AssertionError: when the values don't match
        """
        if our is None or our == recv:
            return
        if field_type:
            expected = field_type.toStr(our)
            received = field_type.toStr(recv)
        else:
            expected = our
            received = recv
        if not f_str:
            f_str = "Expected: {0}, received: {1}"
        raise AssertionError(f_str.format(expected, received))

    @classmethod
    def _cmp_eq_or_in(cls, our, recv, field_type=None, f_str=None):
        """
        Check if received value equals expected or is in expected list.

        If our is None, any value is accepted.
        If our is equal to recv, the value is accepted.
        If our is a container (list, set, ...), check if recv is in it.
        Otherwise (our is a single, non-container value) fall back to
        plain comparison with _cmp_eq().
        If they don't match or are not part of a set, try translating
        them with field_type.toStr() method and raise AssertionError
        formatted with f_str. First parameter to .format() will be
        the expected value and the second one will be the
        received one.

        :raises AssertionError: when the value doesn't match
        """
        # Accept a plain equality match first. Without it, a tuple value
        # (e.g. a protocol version like (3, 3)) would be checked with
        # ``in`` against its own elements and would never match.
        if our is None or our == recv:
            return
        try:
            if recv in our:
                return
        except TypeError:
            # our is a single, non-container value
            return cls._cmp_eq(our, recv, field_type, f_str)
        # doesn't match, so prepare the error message
        if field_type:
            expected = "({0})".format(", ".join(
                field_type.toStr(i) for i in our))
            received = field_type.toStr(recv)
        else:
            expected = our
            received = recv
        if not f_str:
            f_str = "Received value ({1}) not in expected list: {0}"
        raise AssertionError(f_str.format(expected, received))

    @staticmethod
    def _cmp_eq_list(our, recv, field_type=None, f_str=None):
        """
        Check if expected list of values matched received, if defined.

        If our is not None, compare with recv. If they don't match, try
        translating items in the lists with field_type.toStr() method and
        raise AssertionError with message formatted with f_str. First
        parameter to .format() will be list of expected values and the
        second one will be the received one.

        :raises AssertionError: when the lists don't match
        """
        if our is None or our == recv:
            return
        if field_type:
            expected = ", ".join(field_type.toStr(i) for i in our)
            expected = "({0})".format(expected)
            received = ", ".join(field_type.toStr(i) for i in recv)
            received = "({0})".format(received)
        else:
            expected = repr(our)
            received = repr(recv)
        if not f_str:
            f_str = "Expected: {0}, received: {1}"
        raise AssertionError(f_str.format(expected, received))
[docs]
class ExpectHandshake(ExpectMessage):
    """Common methods for handling TLS Handshake protocol messages"""

    def __init__(self, content_type, handshake_type):
        """
        Set the type of message.

        :type content_type: int
        :param content_type: record layer content type of the message
        :type handshake_type: int
        :param handshake_type: value of the first (type) byte of the
            handshake message this node accepts
        """
        super(ExpectHandshake, self).__init__(content_type)
        self.handshake_type = handshake_type

    def is_match(self, msg):
        """
        Check if message is a given type of handshake protocol message.

        Both the record content type and the first byte of the message
        (handshake type) need to match; empty messages never match.
        """
        if not super(ExpectHandshake, self).is_match(msg):
            return False
        raw = msg.write()
        # an empty message has no handshake type byte to inspect
        if not raw:
            return False
        return Parser(raw).get(1) == self.handshake_type

    def process(self, state, msg):
        """Must be overridden by subclasses."""
        raise NotImplementedError("Subclass need to implement this!")
[docs]
def srv_ext_handler_ems(state, extension):
    """
    Handle the extended_master_secret extension sent by server.

    The extension must have an empty payload; when it does, the connection
    state is flagged as using the extended master secret.

    :raises AssertionError: when the extension carries any payload
    """
    payload = extension.extData
    if payload:
        raise AssertionError("Malformed EMS extension, data in payload")
    state.extended_master_secret = True
[docs]
def srv_ext_handler_etm(state, extension):
    """
    Handle the encrypt_then_mac extension sent by server.

    The extension must have an empty payload; when it does, the connection
    state is flagged as using Encrypt-then-MAC.

    :raises AssertionError: when the extension carries any payload
    """
    payload = extension.extData
    if payload:
        raise AssertionError("Malformed EtM extension, data in payload")
    state.encrypt_then_mac = True
[docs]
def srv_ext_handler_sni(state, extension):
    """
    Handle the server_name extension sent by server.

    Server is expected to reply with an empty extension; any payload is
    an error.

    :raises AssertionError: when the extension carries any payload
    """
    del state  # unused; parameter kept for a uniform handler signature
    if not extension.extData:
        return
    raise AssertionError("Malformed SNI extenion, data in payload")
[docs]
def srv_ext_handler_renego(state, extension):
    """
    Handle the renegotiation_info extension sent by server.

    The payload must equal the concatenation of the client and server
    verify_data stored in ``state.key``.

    :raises AssertionError: when the payload doesn't match
    """
    received = extension.renegotiated_connection
    expected = state.key['client_verify_data'] + \
        state.key['server_verify_data']
    if received == expected:
        return
    raise AssertionError("Invalid data in renegotiation_info")
[docs]
def srv_ext_handler_alpn(state, extension):
    """
    Process the ALPN extension from server.

    Server must select exactly one protocol, and it must be one that the
    last ClientHello advertised.

    :raises AssertionError: when client did not send ALPN, when the
        server extension is malformed, or when the server picked a
        protocol we did not advertise
    """
    cln_hello = state.get_last_message_of_type(ClientHello)
    cln_ext = cln_hello.getExtension(ExtensionType.alpn)
    if cln_ext is None:
        # report as a protocol error, not an AttributeError on None
        raise AssertionError("Server sent ALPN extension we did not "
                             "advertise")
    # the sent extension might have been provided with explicit encoding
    # so re-parse its payload to learn which protocols were advertised
    cln_ext = ALPNExtension().parse(Parser(cln_ext.extData))
    if not extension.protocol_names or len(extension.protocol_names) != 1:
        raise AssertionError("Malformed ALPN extension")
    # a client extension without a usable list advertises nothing
    advertised = cln_ext.protocol_names or []
    if extension.protocol_names[0] not in advertised:
        raise AssertionError("Server selected ALPN protocol we did not "
                             "advertise")
[docs]
def srv_ext_handler_ec_point(state, extension):
    """
    Handle the ec_point_formats extension sent by server.

    :raises AssertionError: when the list of formats is missing or empty
    """
    del state  # unused; parameter kept for a uniform handler signature
    if not extension.formats:
        raise AssertionError("Malformed ec_point_formats extension")
[docs]
def srv_ext_handler_npn(state, extension):
    """
    Handle the NPN extension sent by server.

    :raises AssertionError: when the list of protocols is missing or empty
    """
    del state  # unused; parameter kept for a uniform handler signature
    if not extension.protocols:
        raise AssertionError("Malformed NPN extension")
[docs]
def srv_ext_handler_session_ticket(state, extension):
    """
    Handle the session_ticket extension sent by server.

    In ServerHello the extension must be empty.

    :raises AssertionError: when the ticket field is not empty
    """
    del state  # unused; parameter kept for a uniform handler signature
    if extension.ticket == b"":
        return
    raise AssertionError("Malformed session_ticket extension")
[docs]
def srv_ext_handler_key_share(state, extension):
    """
    Handle the key_share extension sent by server.

    Finds the client share for the group the server picked, records the
    server's key exchange value and the computed shared secret in
    ``state.key``.

    :raises AssertionError: when server picked a group the client did not
        send a share for
    :raises ValueError: when the matching client share has no private
        value
    """
    client_hello = state.get_last_message_of_type(ClientHello)
    sent_ext = client_hello.getExtension(ExtensionType.key_share)
    server_share = extension.server_share
    group_id = server_share.group

    our_share = None
    for share in sent_ext.client_shares:
        if share.group == group_id:
            our_share = share
            break
    if our_share is None:
        raise AssertionError("Server selected group we didn't advertise: {0}"
                             .format(GroupName.toStr(group_id)))

    kex = kex_for_group(group_id, state.version)
    state.key['ServerHello.extensions.key_share.key_exchange'] = \
        server_share.key_exchange
    if not our_share.private:
        raise ValueError("private value for key share of group {0} missing"
                         .format(GroupName.toStr(group_id)))
    state.key['DH shared secret'] = kex.calc_shared_key(
        our_share.private, server_share.key_exchange)
[docs]
def hrr_ext_handler_key_share(state, extension):
    """
    Handle the key_share extension in a HelloRetryRequest.

    The group server asks for must be listed in the supported_groups
    extension of the last ClientHello.

    :raises AssertionError: when the selected group was not advertised
    """
    client_hello = state.get_last_message_of_type(ClientHello)
    groups_ext = client_hello.getExtension(ExtensionType.supported_groups)
    selected = extension.selected_group
    if selected in groups_ext.groups:
        return
    raise AssertionError("Server selected group we didn't advertise: {0}"
                         .format(GroupName.toStr(selected)))
[docs]
def hrr_ext_handler_cookie(state, extension):
    """
    Handle the cookie extension in a HelloRetryRequest.

    :raises AssertionError: when the cookie is missing or empty
    """
    del state  # unused; parameter kept for a uniform handler signature
    if extension.cookie:
        return
    raise AssertionError("Server sent empty cookie extension")
[docs]
def srv_ext_handler_supp_vers(state, extension):
    """
    Handle the supported_versions extension sent by server.

    The version server selected must be one advertised in the last
    ClientHello; it then becomes ``state.version``.

    :raises AssertionError: when the version was not advertised
    """
    client_hello = state.get_last_message_of_type(ClientHello)
    sent_ext = client_hello.getExtension(ExtensionType.supported_versions)
    selected = extension.version
    if selected not in sent_ext.versions:
        raise AssertionError("Server selected version we didn't advertise: {0}"
                             .format(selected))
    state.version = selected
[docs]
def srv_ext_handler_supp_groups(state, extension):
    """
    Handle the supported_groups extension sent by server.

    :raises AssertionError: when the list of groups is missing or empty
    """
    del state  # unused; parameter kept for a uniform handler signature
    if extension.groups:
        return
    raise AssertionError("Server did not send any supported_groups")
[docs]
def srv_ext_handler_status_request(state, extension):
    """
    Handle the status_request extension sent by server in ServerHello.

    This is TLS 1.2 ServerHello specific; in TLS 1.3 the extension resides
    in the Certificate message. In ServerHello the extension must be
    empty: no status type, no responder ids and no request extensions.

    :raises AssertionError: when any of the fields is set
    """
    del state  # unused; parameter kept for a uniform handler signature
    is_empty = (extension.status_type is None and
                extension.responder_id_list == [] and
                extension.request_extensions == bytearray())
    if not is_empty:
        raise AssertionError("Server did send non empty status_request "
                             "extension")
[docs]
def srv_ext_handler_heartbeat(state, extension):
    """
    Handle the heartbeat extension sent by server.

    :raises AssertionError: when the mode is missing or is neither
        PEER_ALLOWED_TO_SEND nor PEER_NOT_ALLOWED_TO_SEND
    """
    del state  # unused; parameter kept for a uniform handler signature
    mode = extension.mode
    if not mode:
        raise AssertionError("Empty mode in heartbeat extension.")
    allowed_modes = (HeartbeatMode.PEER_ALLOWED_TO_SEND,
                     HeartbeatMode.PEER_NOT_ALLOWED_TO_SEND)
    if mode not in allowed_modes:
        raise AssertionError("Invalid mode in heartbeat extension.")
[docs]
def _srv_ext_handler_psk(state, extension, psk_configs):
    """
    Process the pre_shared_key extension from server.

    Since it needs the psk_configs, it can't be used automatically,
    so it isn't part of _srv_ext_handler; create it with
    gen_srv_ext_handler_psk().

    The identity the server selected (by zero-based index into the
    identities of the ClientHello pre_shared_key extension) is resolved
    to a secret stored in ``state.key['PSK secret']``. If it matches the
    most recent session ticket, the resumption PSK is derived from it.
    Otherwise the secret is looked up in ``psk_configs``.

    :param psk_configs: iterable of tuples with the identity as the first
        element and the secret as the second
    :raises AssertionError: when the selected index is out of range
    :raises ValueError: when the identity is not found in psk_configs
        (or its secret there is empty)
    """
    client_hello = state.get_last_message_of_type(ClientHello)
    sent_ext = client_hello.getExtension(ExtensionType.pre_shared_key)
    index = extension.selected
    if index >= len(sent_ext.identities):
        raise AssertionError("Server selected PSK we didn't send")
    chosen = sent_ext.identities[index]
    ident = chosen.identity

    if state.session_tickets:
        ticket = state.session_tickets[-1]
        if ticket.ticket == ident:
            state.key['PSK secret'] = HandshakeHelpers.calc_res_binder_psk(
                chosen,
                state.key['resumption master secret'],
                [ticket])
            return

    for config in psk_configs:
        if config[0] == ident:
            secret = config[1]
            break
    else:
        secret = None
    if not secret:
        raise ValueError("psk_configs are missing identity")
    state.key['PSK secret'] = secret
[docs]
def gen_srv_ext_handler_psk(psk_configs=tuple()):
    """
    Create a handler for the pre_shared_key extension from the server.

    :param psk_configs: tuples of (identity, secret, ...) used to resolve
        the identity the server selected
    :return: callable taking ``(state, extension)``
    """
    handler = partial(_srv_ext_handler_psk, psk_configs=psk_configs)
    return handler
[docs]
def _srv_ext_handler_record_limit(state, extension, size=None):
    """
    Process record_size_limit extension from server.

    Validates the limit the server sent and applies it. In TLS 1.2 and
    earlier the limits are stored in the state so that they can be
    applied when the server CCS is processed. In TLS 1.3 the record
    layer limits are set immediately.

    :param state: connection state
    :param extension: record_size_limit extension received from server
    :param int size: expected value of the limit, None (or 0) to accept
        any valid value
    :raises AssertionError: when the value is missing, out of the allowed
        range, or different from ``size``
    """
    cln_hello = state.get_last_message_of_type(ClientHello)
    cln_ext = cln_hello.getExtension(ExtensionType.record_size_limit)
    limit = extension.record_size_limit
    # explicit checks instead of ``assert`` statements so that protocol
    # validation is not stripped when running under ``python -O``
    if limit is None:
        raise AssertionError("Malformed record_size_limit extension, "
                             "missing value")
    # in TLS 1.3 the value also counts the content type byte (the record
    # layer below is given the value minus one), so one more is allowed
    max_limit = 2**14 + int(state.version > (3, 3))
    if not 64 <= limit <= max_limit:
        raise AssertionError("Server sent record_size_limit outside of the "
                             "allowed range (64 to {0}): {1}"
                             .format(max_limit, limit))
    if size and limit != size:
        raise AssertionError("Server sent unexpected size in extension, "
                             "expected size: {0}, received size: {1}"
                             .format(size, limit))
    if state.version <= (3, 3):
        # in TLS 1.2 and earlier we need to delay that to processing of
        # server CCS
        state._peer_record_size_limit = limit
        state._our_record_size_limit = min(2**14, cln_ext.record_size_limit)
    else:
        # in TLS 1.3 we need to implement it right away (as the extension
        # applies only to encrypted messages)
        # the RecordLayer expects value that excludes content type
        state.msg_sock.recv_record_limit = min(
            2**14,
            cln_ext.record_size_limit - 1)
        # this is just hint for padding callback
        state.msg_sock.send_record_limit = min(
            2**14,
            limit - 1)
        # this guides fragmentation
        state.msg_sock.recordSize = state.msg_sock.send_record_limit
[docs]
def gen_srv_ext_handler_record_limit(size=None):
    """
    Create a handler for the record_size_limit extension from the server.

    Note that if the extension is actually negotiated, it will override
    any `~SetMaxRecordSize()` before EncryptedExtensions in TLS 1.3 and
    before ChangeCipherSpec in TLS 1.2 and earlier.

    :param int size: expected value from server, None for any valid
    :return: callable taking ``(state, extension)``
    """
    handler = partial(_srv_ext_handler_record_limit, size=size)
    return handler
[docs]
def clnt_ext_handler_status_request(state, extension):
    """
    Validate status_request extension sent by the initiating side.

    Suitable for ClientHello and CertificateRequest. Only the OCSP status
    type is accepted, and both the responder id list and the request
    extensions must be present.

    :raises AssertionError: on unexpected status type or missing fields
    """
    del state  # unused; parameter kept for API compatibility
    status_type = extension.status_type
    if status_type != CertificateStatusType.ocsp:
        raise AssertionError(
            "Unexpected status_type in status_request extension: {0}"
            .format(CertificateStatusType.toStr(status_type)))
    missing_field = extension.responder_id_list is None or \
        extension.request_extensions is None
    if missing_field:
        raise AssertionError(
            "Malformed status_request extension")
[docs]
def clnt_ext_handler_sig_algs(state, extension):
    """
    Validate signature_algorithms or signature_algorithms_cert extension.

    Suitable for ClientHello and CertificateRequest.

    :raises AssertionError: when the list of algorithms is missing or empty
    """
    del state  # unused; parameter kept for API compatibility
    if extension.sigalgs:
        return
    raise AssertionError(
        "Empty or malformed {0} extension"
        .format(ExtensionType.toStr(extension.extType)))
# Automatic handlers for extensions in ServerHello. They are used by
# ExpectServerHello._get_autohandler() when no explicit handler was given
# (ExpectHelloRetryRequest uses them as a fallback too).
# pre_shared_key is deliberately absent: its handler needs psk_configs,
# so it has to be created with gen_srv_ext_handler_psk().
_srv_ext_handler = {
    ExtensionType.extended_master_secret: srv_ext_handler_ems,
    ExtensionType.encrypt_then_mac: srv_ext_handler_etm,
    ExtensionType.server_name: srv_ext_handler_sni,
    ExtensionType.renegotiation_info: srv_ext_handler_renego,
    ExtensionType.alpn: srv_ext_handler_alpn,
    ExtensionType.session_ticket: srv_ext_handler_session_ticket,
    ExtensionType.ec_point_formats: srv_ext_handler_ec_point,
    ExtensionType.supports_npn: srv_ext_handler_npn,
    ExtensionType.key_share: srv_ext_handler_key_share,
    ExtensionType.supported_versions: srv_ext_handler_supp_vers,
    ExtensionType.heartbeat: srv_ext_handler_heartbeat,
    ExtensionType.record_size_limit: _srv_ext_handler_record_limit,
    ExtensionType.status_request: srv_ext_handler_status_request,
}

# Handlers specific to HelloRetryRequest; ExpectHelloRetryRequest checks
# this table before falling back to _srv_ext_handler.
_HRR_EXT_HANDLER = {
    ExtensionType.key_share: hrr_ext_handler_key_share,
    ExtensionType.cookie: hrr_ext_handler_cookie,
}

# Handlers for extensions in TLS 1.3 EncryptedExtensions (presumably used
# by the EncryptedExtensions expect node, which is not visible here).
_EE_EXT_HANDLER = {
    ExtensionType.server_name: srv_ext_handler_sni,
    ExtensionType.alpn: srv_ext_handler_alpn,
    ExtensionType.supported_groups: srv_ext_handler_supp_groups,
    ExtensionType.heartbeat: srv_ext_handler_heartbeat,
    ExtensionType.record_size_limit: _srv_ext_handler_record_limit,
}

# Handlers for extensions in TLS 1.3 CertificateRequest (the request is
# sent by the server, so the "client side" checks apply).
_CR_EXT_HANDLER = {
    ExtensionType.status_request: clnt_ext_handler_status_request,
    ExtensionType.signature_algorithms: clnt_ext_handler_sig_algs,
    ExtensionType.signature_algorithms_cert: clnt_ext_handler_sig_algs,
}
[docs]
class _ExpectExtensionsMessage(ExpectHandshake):
    """
    Common methods of messages that have a list of extensions.

    Used in ServerHello, EncryptedExtensions and CertificateRequest (in
    TLS 1.3)
    """

    def __init__(self, content_type, msg_type, extensions):
        """
        :param int content_type: record layer content type
        :param int msg_type: handshake message type
        :param dict extensions: expected extensions keyed by extension
            type, None to not check the set of received extensions
        """
        super(_ExpectExtensionsMessage, self).__init__(
            content_type, msg_type)
        self.extensions = extensions

    def _compare_extensions(self, message):
        """
        Verify that server provided extensions match exactly expected list.

        Only the extension types are compared here, not their contents.

        :raises AssertionError: when an expected extension is missing or
            an unexpected one was received
        """
        wanted = self.extensions
        received = message.extensions
        if wanted and not received:
            raise AssertionError("Server did not send any extensions")
        # nothing to compare if no expectations were set or the message
        # carried no extensions
        if wanted is None or not received:
            return
        expected = set(wanted.keys())
        got = set(ext.extType for ext in received)
        missing = expected - got
        if missing:
            raise AssertionError("Server did not send extension(s): "
                                 "{0}".format(
                                     ", ".join(ExtensionType.toStr(i)
                                               for i in missing)))
        unexpected = got - expected
        if unexpected:
            raise AssertionError("Server sent unexpected extension(s):"
                                 " {0}".format(
                                     ", ".join(ExtensionType.toStr(i)
                                               for i in unexpected)))
[docs]
class ExpectServerHello(_ExpectExtensionsMessage):
    """
    Parsing TLS Handshake protocol Server Hello messages.

    Processing of the ServerHello message updates the record layer
    to the version advertised by the server.
    Use :py:class:`~tlsfuzzer.messages.SetRecordVersion` to change it earlier
    to send records with different versions.

    .. note::
        Receiving of the ServerHello in TLS 1.3 influences record layer
        encryption. After the message is received, the
        ``client_handshake_traffic_secret`` and
        ``server_handshake_traffic_secret``
        are derived and record layer is configured to expect encrypted
        records on the *receiving* side.

    :ivar str ~.description: identifier to print when processing of the
        node fails
    """

    def __init__(self, extensions=None, version=None, resume=False,
                 cipher=None, server_max_protocol=None, force_resume=False,
                 description=None):
        """
        Initialize the object

        :param dict extensions: extension objects to match the server sent
            extensions or callbacks to process and verify them. None means
            use automatic handlers that will verify the response against
            the extensions sent in ClientHello. Empty dict means that the
            server is expected to send no extensions. Order does not
            matter, but all extensions present and only extensions present
            in the list must be sent by server. None as the value of the
            relevant extension type can be used to select autohandler for
            a given extension type.
        :param tuple version: the literal version in the Server Hello
            message (needs to be (3, 3) for TLS 1.3, use extensions to
            expect TLS 1.3 negotiation)
        :param tuple server_max_protocol: the highest protocol version
            supported by server. Used for testing downgrade signaling of
            servers.
        :type cipher: int or set-like
        :param int cipher: the id of the cipher that is expected to be
            negotiated by server. Can also be a list or set (needs to
            support ``in``) for a set of allowed ciphers.
            None (the default) means any valid cipher
            (i.e. not SCSV or GREASE) sent in ClientHello can be selected
            by server.
        :type resume: boolean
        :param resume: whether the session id should match the one from
            current state - IOW, if the server hello should belong to a
            resumed session. TLS 1.2 and earlier only. In TLS 1.3
            resumption is handled by providing handler for
            ``pre_shared_key`` extension.
        :param boolean force_resume: assume that the session is getting
            resumed, even if the sessionID is empty. Applicable to TLS 1.2
            and earlier only when using session tickets and not sending a
            sessionID.
        :param str description: identifier to print when processing of
            the node fails
        """
        super(ExpectServerHello, self).__init__(ContentType.handshake,
                                                HandshakeType.server_hello,
                                                extensions)
        self.cipher = cipher
        self.version = version
        self.resume = resume
        self.srv_max_prot = server_max_protocol
        self.force_resume = force_resume
        self.description = description

    def __str__(self):
        """Return human readable representation of the object."""
        if self.description:
            return "ExpectServerHello(description={0!r})"\
                .format(self.description)
        return "ExpectServerHello()"

    @staticmethod
    def _get_autohandler(ext_id):
        """
        Return the automatic handler for extension type ``ext_id``.

        :raises AssertionError: when there is no automatic handler for it
        """
        try:
            return _srv_ext_handler[ext_id]
        except KeyError:
            raise AssertionError("No autohandler for "
                                 "{0}"
                                 .format(ExtensionType
                                         .toStr(ext_id)))

    @staticmethod
    def _extract_version(msg):
        """
        Return the protocol version selected in a ServerHello (or HRR).

        In TLS 1.3 the legacy ``server_version`` field stays at (3, 3) and
        the negotiated version is carried in the supported_versions
        extension, so the extension takes precedence when present.
        """
        ext = msg.getExtension(ExtensionType.supported_versions)
        if ext is not None:
            return ext.version
        return msg.server_version

    def _process_extensions(self, state, cln_hello, srv_hello):
        """
        Check if extensions are correct.

        Every extension sent by server needs to be allowed in the message
        (in TLS 1.3 only a few are allowed in ServerHello and
        HelloRetryRequest) and needs to have been advertised in
        ClientHello. It is then verified by the handler selected for it:
        a callable is called with ``(state, extension)``, a TLSExtension
        object is compared with the received extension and None selects
        the automatic handler.
        """
        # extensions allowed in TLS 1.3 ServerHello and HelloRetryRequest
        # messages (as some need to be echoed by server in
        # EncryptedExtensions and some in Certificate)
        sh_supported = [ExtensionType.pre_shared_key,
                        ExtensionType.supported_versions,
                        ExtensionType.key_share]
        hrr_supported = [ExtensionType.cookie,
                         ExtensionType.supported_versions,
                         ExtensionType.key_share]
        for ext in srv_hello.extensions:
            ext_id = ext.extType
            if state.version > (3, 3) and \
                    ((srv_hello.random != TLS_1_3_HRR and
                      ext_id not in sh_supported) or
                     (srv_hello.random == TLS_1_3_HRR and
                      ext_id not in hrr_supported)):
                raise AssertionError("Server sent unallowed "
                                     "extension of type {0}"
                                     .format(ExtensionType
                                             .toStr(ext_id)))
            # in TLS 1.2 generally the server can reply to any client sent
            # extension, and all of them end in ClientHello
            cl_ext = cln_hello.getExtension(ext_id)
            # renegotiation_info is also acceptable when the client
            # signalled support with the SCSV instead of the extension
            if ext_id == ExtensionType.renegotiation_info and \
                    CipherSuite.TLS_EMPTY_RENEGOTIATION_INFO_SCSV \
                    in cln_hello.cipher_suites:
                cl_ext = True
            # the cookie in HRR doesn't need to be sent by client first
            if isinstance(self, ExpectHelloRetryRequest) and \
                    ext_id == ExtensionType.cookie:
                cl_ext = True
            if cl_ext is None:
                raise AssertionError("Server sent unadvertised "
                                     "extension of type {0}"
                                     .format(ExtensionType
                                             .toStr(ext_id)))
            handler = None
            if self.extensions:
                handler = self.extensions[ext_id]
            # use automatic handlers for some extensions
            if handler is None:
                handler = self._get_autohandler(ext_id)
            if callable(handler):
                handler(state, ext)
            elif isinstance(handler, TLSExtension):
                if not handler == ext:
                    raise AssertionError("Expected extension not "
                                         "matched for type {0}, "
                                         "received: {1}"
                                         .format(ExtensionType
                                                 .toStr(ext_id),
                                                 ext))
            else:
                raise ValueError("Bad extension handler for id {0}"
                                 .format(ExtensionType.toStr(ext_id)))

    def process(self, state, msg):
        """
        Process the message and update state accordingly

        :type state: ConnectionState
        :param state: overall state of TLS connection
        :type msg: Message
        :param msg: TLS Message read from socket
        :rtype: ServerHello
        :return: the parsed message
        """
        assert msg.contentType == ContentType.handshake
        parser = Parser(msg.write())
        hs_type = parser.get(1)
        assert hs_type == HandshakeType.server_hello

        srv_hello = ServerHello()
        srv_hello.parse(parser)

        # extract important info
        state.server_random = srv_hello.random
        srv_version = self._extract_version(srv_hello)

        cln_hello = state.get_last_message_of_type(ClientHello)

        # check for session_id based session resumption
        if self.resume:
            assert state.session_id == srv_hello.session_id
        if self.force_resume or \
                ((state.session_id == srv_hello.session_id or
                  cln_hello.session_id == srv_hello.session_id) and
                 srv_hello.session_id != bytearray(0) and
                 srv_version < (3, 4)):
            # TLS 1.2 resumption, TLS 1.3 is based on PSKs
            state.resuming = True
            assert state.cipher == srv_hello.cipher_suite
            assert state.version == srv_version

        state.session_id = srv_hello.session_id

        self._cmp_eq(self.version, srv_hello.server_version,
                     f_str="Server selected unexpected protocol version. "
                           "Expected: {0}, received: {1}.")

        self._cmp_eq_or_in(
            self.cipher, srv_hello.cipher_suite,
            f_str="Server selected unexpected ciphersuite. "
                  "Expected: {0}, received: {1}.")

        # check if server sent cipher matches what we advertised in CH
        if srv_hello.cipher_suite not in cln_hello.cipher_suites:
            cipher = srv_hello.cipher_suite
            if cipher in CipherSuite.ietfNames:
                name = "{0} ({1:#06x})".format(CipherSuite.ietfNames[cipher],
                                               cipher)
            else:
                name = "{0:#06x}".format(cipher)
            raise AssertionError("Server responded with cipher we did"
                                 " not advertise: {0}".format(name))

        state.cipher = srv_hello.cipher_suite
        state.version = srv_version

        # update the state of connection
        state.msg_sock.version = state.version
        state.msg_sock.tls13record = state.version > (3, 3)

        # needs to run before this message is appended to
        # handshake_messages so that the previous ServerHello (HRR) is found
        self._check_against_hrr(state, srv_hello)

        state.handshake_messages.append(srv_hello)
        state.handshake_hashes.update(msg.write())

        # Reset value of the session-wide settings
        state.extended_master_secret = False
        state.encrypt_then_mac = False

        self._check_downgrade_protection(srv_hello)

        self._compare_extensions(srv_hello)

        if srv_hello.extensions:
            self._process_extensions(state, cln_hello, srv_hello)

        if state.version > (3, 3):
            self._setup_tls13_handshake_keys(state)
        return srv_hello

    @staticmethod
    def _check_against_hrr(state, srv_hello):
        """
        Verify that a TLS 1.3 ServerHello agrees with the preceding HRR.

        Applies only when the last ServerHello in the state was a
        HelloRetryRequest: cipher suite and selected protocol version
        need to be the same in both messages.

        :raises AssertionError: on any mismatch
        """
        if state.version < (3, 4):
            return
        hrr = state.get_last_message_of_type(ServerHello)
        if not hrr or hrr.random != TLS_1_3_HRR:
            # not an HRR, so HRR tests don't apply to it
            return
        if hrr.cipher_suite != srv_hello.cipher_suite:
            raise AssertionError("Server picked different cipher suite than "
                                 "it advertised in HelloRetryRequest")
        hrr_version = hrr.getExtension(ExtensionType.supported_versions)
        sh_version = srv_hello.getExtension(ExtensionType.supported_versions)
        # report a missing extension as a protocol error instead of
        # failing with AttributeError on None
        if hrr_version is None or sh_version is None:
            raise AssertionError("Missing supported_versions extension in "
                                 "HelloRetryRequest or ServerHello")
        if hrr_version.version != sh_version.version:
            raise AssertionError("Server picked different protocol version "
                                 "than it advertised in HelloRetryRequest")

    def _setup_tls13_handshake_keys(self, state):
        """Set up the encryption keys for the TLS 1.3 handshake."""
        del self
        prf_name = state.prf_name
        prf_size = state.prf_size

        # Derive PSK secret
        psk = state.key.setdefault('PSK secret', bytearray(prf_size))

        # Derive TLS 1.3 early secret
        secret = bytearray(prf_size)
        secret = secureHMAC(secret, psk, prf_name)
        state.key['early secret'] = secret

        # Derive TLS 1.3 handshake secret
        secret = derive_secret(secret, b'derived', None, prf_name)
        dh_secret = state.key.setdefault('DH shared secret',
                                         bytearray(prf_size))
        secret = secureHMAC(secret, dh_secret, prf_name)
        state.key['handshake secret'] = secret

        # Derive TLS 1.3 traffic secrets
        s_traffic_secret = derive_secret(secret, b's hs traffic',
                                         state.handshake_hashes,
                                         prf_name)
        state.key['server handshake traffic secret'] = s_traffic_secret
        c_traffic_secret = derive_secret(secret, b'c hs traffic',
                                         state.handshake_hashes,
                                         prf_name)
        state.key['client handshake traffic secret'] = c_traffic_secret

        state.msg_sock.calcTLS1_3PendingState(
            state.cipher, c_traffic_secret, s_traffic_secret, None)

        state.msg_sock.changeReadState()

    def _check_downgrade_protection(self, srv_hello):
        """
        Verify that server provided downgrade protection as specified in
        RFC 8446, Section 4.1.3

        :raises AssertionError: when the sentinel is set but shouldn't be,
            or is expected but missing
        """
        version = self._extract_version(srv_hello)
        sentinel = srv_hello.random[24:]
        # even if we don't know which version server supports, some values
        # are obviously incorrect:
        if (version > (3, 3) and sentinel == TLS_1_2_DOWNGRADE_SENTINEL) or \
                (version > (3, 2) and
                 sentinel == TLS_1_1_DOWNGRADE_SENTINEL):
            raise AssertionError(
                "Server set downgrade protection sentinel but shouldn't "
                "have done that")

        # as we're doing both TLS 1.2 tests and TLS 1.3 tests with
        # `scripts/` we don't know when setting the sentinel is expected
        # and when it is not as the negotiation might have ended up with
        # TLS 1.2 because that was the highest version we advertised
        if self.srv_max_prot is None:
            return

        downgrade_value = None
        if self.srv_max_prot > (3, 3) and version == (3, 3):
            downgrade_value = TLS_1_2_DOWNGRADE_SENTINEL
        elif self.srv_max_prot > (3, 2) and version < (3, 3):
            downgrade_value = TLS_1_1_DOWNGRADE_SENTINEL
        else:
            if sentinel == TLS_1_1_DOWNGRADE_SENTINEL or \
                    sentinel == TLS_1_2_DOWNGRADE_SENTINEL:
                raise AssertionError(
                    "Server set downgrade protection sentinel but shouldn't "
                    "have done that")

        if downgrade_value is not None:
            if sentinel != downgrade_value:
                raise AssertionError(
                    "Server failed to set downgrade protection sentinel in "
                    "ServerHello.random value")
[docs]
class ExpectHelloRetryRequest(ExpectServerHello):
    """Processing of the TLS 1.3 HelloRetryRequest message."""

    def __init__(self, extensions=None, version=None, cipher=None):
        """
        Initialize the object.

        :param dict extensions: expected extensions or their handlers, see
            ExpectServerHello
        :param tuple version: expected legacy version field
        :param cipher: expected cipher suite id (or container of ids)
        """
        # pass by keyword: the third positional parameter of
        # ExpectServerHello is ``resume``, not ``cipher``
        super(ExpectHelloRetryRequest, self).__init__(
            extensions=extensions, version=version, cipher=cipher)
        # handshake hashes from before the HRR was received
        self._ch_hh = None
        # the raw HRR message
        self._msg = None

    def process(self, state, msg):
        """
        Process the HelloRetryRequest and update state accordingly.

        :return: the parsed message
        :raises AssertionError: when the message is not an HRR
        """
        self._ch_hh = state.handshake_hashes.copy()
        self._msg = msg
        hrr = super(ExpectHelloRetryRequest, self).process(state, msg)

        assert hrr.random == TLS_1_3_HRR
        return hrr

    @staticmethod
    def _get_autohandler(ext_id):
        """
        Return the automatic handler for extension type ``ext_id``.

        HRR specific handlers take precedence over the ServerHello ones.

        :raises AssertionError: when there is no automatic handler for it
        """
        try:
            return _HRR_EXT_HANDLER[ext_id]
        except KeyError:
            try:
                return _srv_ext_handler[ext_id]
            except KeyError:
                raise AssertionError("No autohandler for {0}".format(
                    ExtensionType.toStr(ext_id)))

    def _setup_tls13_handshake_keys(self, state):
        """
        Prepare handshake hashes for the HRR handling.

        No keys are derived. Instead the handshake hashes are replaced by a
        synthetic ``message_hash`` handshake message holding the hash of
        the messages before HRR, followed by the HRR message itself.
        """
        prf_name = state.prf_name
        ch_hash = self._ch_hh.digest(prf_name)
        new_hh = HandshakeHashes()
        writer = Writer()
        writer.add(HandshakeType.message_hash, 1)
        writer.addVarSeq(ch_hash, 1, 3)
        new_hh.update(writer.bytes)

        new_hh.update(self._msg.write())

        state.handshake_hashes = new_hh
[docs]
class ExpectServerHello2(ExpectHandshake):
    """Processing of SSLv2 Handshake Protocol SERVER-HELLO message"""

    def __init__(self, version=None):
        """
        :param tuple version: expected protocol version selected by
            server, None to accept any
        """
        c_type = ContentType.handshake
        h_type = SSL2HandshakeType.server_hello
        super(ExpectServerHello2, self).__init__(c_type,
                                                 h_type)
        self.version = version

    def process(self, state, msg):
        """
        Process the message and update state accordingly

        :type state: `~ConnectionState`
        :param state: overall state of TLS connection
        :type msg: Message
        :param msg: TLS Message read from socket
        """
        # the value is faked for SSLv2 protocol, but let's just check sanity
        assert msg.contentType == ContentType.handshake

        parser = Parser(msg.write())
        hs_type = parser.get(1)
        assert hs_type == SSL2HandshakeType.server_hello

        server_hello = ServerHello2().parse(parser)

        state.handshake_messages.append(server_hello)
        state.handshake_hashes.update(msg.write())

        self._cmp_eq(self.version, server_hello.server_version,
                     f_str="Server picked unexpected protocol version. "
                           "Expected: {0}, received: {1}.")

        if server_hello.session_id_hit:
            state.resuming = True
        state.session_id = server_hello.session_id
        # NOTE(review): the value parsed into session_id (presumably the
        # SSLv2 connection id) is used as the server random; verify against
        # ServerHello2
        state.server_random = server_hello.session_id
        state.version = server_hello.server_version
        state.msg_sock.version = server_hello.server_version

        # fake a certificate message so finding the server public key works
        x509 = X509()
        x509.parseBinary(server_hello.certificate)
        cert_chain = X509CertChain([x509])
        certificate = Certificate(CertificateType.x509)
        certificate.create(cert_chain)
        state.handshake_messages.append(certificate)
        # fake message so don't update handshake hashes
[docs]
class ExpectCertificate(ExpectHandshake):
    """Processing TLS Handshake protocol Certificate messages"""

    def __init__(self, cert_type=CertificateType.x509):
        """
        :param int cert_type: type of certificates expected in the message
        """
        super(ExpectCertificate, self).__init__(ContentType.handshake,
                                                HandshakeType.certificate)
        self.cert_type = cert_type
        # cache of the last parsed message. Parsing is skipped when the
        # same bytes are received again for the same protocol version
        # (the message format differs between TLS 1.3 and earlier).
        self._old_cert = None
        self._old_cert_bytes = None
        self._old_cert_version = None

    def process(self, state, msg):
        """
        Parse the Certificate message and record it in the state.

        :type state: `~ConnectionState`
        """
        assert msg.contentType == ContentType.handshake
        msg_bytes = msg.write()
        if self._old_cert_bytes is not None and \
                msg_bytes == self._old_cert_bytes and \
                state.version == self._old_cert_version:
            cert = self._old_cert
        else:
            parser = Parser(msg_bytes)
            hs_type = parser.get(1)
            assert hs_type == HandshakeType.certificate
            cert = Certificate(self.cert_type, state.version)
            cert.parse(parser)
            self._old_cert_bytes = msg_bytes
            self._old_cert_version = state.version
            self._old_cert = cert

        state.handshake_messages.append(cert)
        state.handshake_hashes.update(msg_bytes)
[docs]
class ExpectCertificateVerify(ExpectHandshake):
    """
    Processing TLS Handshake protocol Certificate Verify messages.

    The signature is verified over the TLS 1.3 server CertificateVerify
    signature context.

    :param tuple(int,int) version: Expected TLS version of the message. If not
        provided will be taken from the state.
    :param tuple(int,int) sig_alg: Expected value of the signature scheme
        created by the server. If not provided it will be compared with
        signature algorithm extension from client hello, and checked to be
        usable with the type of the server key.
    :param hash_file: file-like object (written with bytes) where hashes of
        the signature context will be logged, ``None`` to not log them
    :param sig_file: file-like object (written with bytes) where the
        signatures themselves will be logged, ``None`` to not log them
    """

    def __init__(self, version=None, sig_alg=None, hash_file=None,
                 sig_file=None):
        super(ExpectCertificateVerify, self).__init__(
            ContentType.handshake,
            HandshakeType.certificate_verify)
        self.version = version
        self.sig_alg = sig_alg
        self.hash_file = hash_file
        self.sig_file = sig_file

    @staticmethod
    def _check_sig_alg_matches_key(state, sig_alg):
        """
        Check the signature scheme against ClientHello and the server key.

        :raises AssertionError: when the scheme was not advertised by the
            client or can't be used with the server key in TLS 1.3
        """
        c_hello = state.get_last_message_of_type(ClientHello)
        ext = c_hello.getExtension(ExtensionType.signature_algorithms)
        assert sig_alg in ext.sigalgs

        key_type = state.get_server_public_key().key_type
        if key_type == "rsa-pss":
            # in TLS 1.3 only RSA-PSS signatures are allowed
            assert sig_alg in (SignatureScheme.rsa_pss_pss_sha256,
                               SignatureScheme.rsa_pss_pss_sha384,
                               SignatureScheme.rsa_pss_pss_sha512)
        elif key_type == "rsa":
            # in TLS 1.3 only RSA-PSS signatures are allowed
            assert sig_alg in (SignatureScheme.rsa_pss_rsae_sha256,
                               SignatureScheme.rsa_pss_rsae_sha384,
                               SignatureScheme.rsa_pss_rsae_sha512)
        elif key_type in ("Ed25519", "Ed448"):
            assert sig_alg in (SignatureScheme.ed25519,
                               SignatureScheme.ed448)
            # the EdDSA scheme must be the one named after the key type
            if getattr(SignatureScheme, key_type.lower()) != sig_alg:
                raise AssertionError(
                    "Mismatched signature ({0}) for used key ({1})"
                    .format(SignatureScheme.toStr(sig_alg), key_type))
        else:
            assert key_type == "ecdsa"
            curve_name = state.get_server_public_key().curve_name
            assert curve_name in ("NIST256p", "NIST384p", "NIST521p")
            assert sig_alg in ECDSA_SIG_TLS1_3_ALL
            # in TLS 1.3 the hash is bound to key curve
            curve_hash = curve_name_to_hash_tls13(curve_name)
            expected = (getattr(HashAlgorithm, curve_hash),
                        SignatureAlgorithm.ecdsa)
            if sig_alg != expected:
                raise AssertionError(
                    "Invalid signature type for {1} key, "
                    "received: {0}"
                    .format(SignatureScheme.toStr(sig_alg), curve_name))

    @staticmethod
    def _verification_parameters(salg):
        """Return ``(hash_name, padding, salt_len)`` for ``hashAndVerify``."""
        if salg in (SignatureScheme.ed25519, SignatureScheme.ed448):
            # the hash is intrinsic to EdDSA
            return "intrinsic", None, None
        if salg[1] == SignatureAlgorithm.ecdsa:
            return HashAlgorithm.toStr(salg[0]), None, None
        scheme = SignatureScheme.toRepr(salg)
        hash_name = SignatureScheme.getHash(scheme)
        padding = SignatureScheme.getPadding(scheme)
        # salt as long as the output of the hash used by the scheme
        salt_len = getattr(hashlib, hash_name)().digest_size
        return hash_name, padding, salt_len

    def process(self, state, msg):
        """
        Parse and verify the CertificateVerify message, then record it.

        :type state: `~ConnectionState`
        :raises AssertionError: when the signature scheme is not acceptable
            or the signature does not verify
        """
        assert msg.contentType == ContentType.handshake
        parser = Parser(msg.write())
        hs_type = parser.get(1)
        assert hs_type == HandshakeType.certificate_verify

        if self.version is None:
            self.version = state.version

        cert_v = CertificateVerify(self.version)
        cert_v.parse(parser)
        salg = cert_v.signatureAlgorithm

        if self.sig_alg:
            assert self.sig_alg == salg
        else:
            self._check_sig_alg_matches_key(state, salg)

        hash_name, padding, salt_len = self._verification_parameters(salg)

        transcript_hash = state.handshake_hashes.digest(state.prf_name)
        sig_context = bytearray(b'\x20' * 64 +
                                b'TLS 1.3, server CertificateVerify' +
                                b'\x00') + transcript_hash

        server_key = state.get_server_public_key()
        verified = server_key.hashAndVerify(cert_v.signature,
                                            sig_context,
                                            padding,
                                            hash_name,
                                            salt_len)
        if not verified:
            raise AssertionError("Signature verification failed")

        if self.hash_file:
            digest = getattr(hashlib, hash_name)(sig_context).digest()
            self.hash_file.write(digest)
        if self.sig_file:
            self.sig_file.write(cert_v.signature)

        state.handshake_messages.append(cert_v)
        state.handshake_hashes.update(msg.write())
[docs]
class ExpectServerKeyExchange(ExpectHandshake):
    """Processing TLS Handshake protocol Server Key Exchange message"""

    def __init__(self, version=None, cipher_suite=None, valid_sig_algs=None,
                 valid_groups=None, valid_params=None):
        """
        Expect ServerKeyExchange message from server.

        :param tuple(int,int) version: protocol version used to parse the
            message, taken from the connection state if not provided
        :param int cipher_suite: cipher suite used to parse the message,
            taken from the connection state if not provided
        :param list(tuple(int,int)) valid_sig_algs: signature algorithms the
            server may use to sign the message. If not provided, the ones
            advertised in the ClientHello signature_algorithms extension are
            used, or SHA-1 based ones if that extension was not sent.
        :param list(int) valid_groups: TLS group identifiers for groups that
            server can use. In case the groups include identifiers between 256
            and 512 (see RFC 7919), the node will also check that the server
            selected FFDH parameters match the parameters specified in the RFC.
            Identifiers from that range that don't name a group defined in
            RFC 7919 are ignored for this check. For ECDHE cipher suites the
            groups are the accepted curves; if not provided, the groups from
            the ClientHello supported_groups extension (or all EC groups)
            are accepted.
        :param set(tuple(int,int)) valid_params: set of explicit expected
            parameters used by the server, the first element of the tuple
            is the expected generator and the second is the prime used for the
            DH calculation. Applicable only to ciphersuites that use FFDHE
            key exchange.
        :raises ValueError: when both valid_groups and valid_params are set
        """
        msg_type = HandshakeType.server_key_exchange
        super(ExpectServerKeyExchange, self).__init__(ContentType.handshake,
                                                      msg_type)
        self.version = version
        self.cipher_suite = cipher_suite
        self.valid_sig_algs = valid_sig_algs
        self.valid_groups = valid_groups
        self.valid_params = valid_params
        if self.valid_groups and self.valid_params:
            raise ValueError("valid_groups and valid_params are exclusive")

    def _checkParams(self, server_key_exchange):
        """
        Check that the server selected allowed FFDHE parameters.

        The allowed parameters are either ``valid_params`` or the RFC 7919
        parameters of the groups listed in ``valid_groups``. When neither
        yields any parameters, no check is performed.

        :raises AssertionError: when the server parameters are not allowed
        """
        # only groups present in RFC7919_GROUPS have known parameters; other
        # identifiers from the FFDHE range (e.g. private use ones) would
        # index past the end of the table
        groups = [RFC7919_GROUPS[i - 256]
                  for i in (self.valid_groups or ())
                  if 256 <= i < 256 + len(RFC7919_GROUPS)]
        if self.valid_params:
            groups = self.valid_params

        server_params = (server_key_exchange.dh_g, server_key_exchange.dh_p)

        if groups and server_params not in groups:
            # name well-known parameters in the error message when possible
            for name, params in FFDHE_PARAMETERS.items():
                if server_params == params:
                    raise AssertionError(
                        "DH parameters not from valid set, "
                        "received: {0}".format(name))
            raise AssertionError(
                "DH parameters not from valid set, "
                "received: g:{0}, p:{1}".format(
                    hex(server_params[0]),
                    hex(server_params[1])))

    def process(self, state, msg):
        """
        Process the Server Key Exchange message.

        Verifies the signature over the key exchange parameters, checks the
        FFDHE parameters (for DHE suites), sets up ``state.key_exchange`` and
        stores the server key share and the premaster secret in
        ``state.key``.

        :raises AssertionError: when the parameters are not acceptable or the
            cipher suite uses an unsupported key exchange
        """
        assert msg.contentType == ContentType.handshake
        parser = Parser(msg.write())
        hs_type = parser.get(1)
        assert hs_type == HandshakeType.server_key_exchange

        if self.version is None:
            self.version = state.version
        if self.cipher_suite is None:
            self.cipher_suite = state.cipher
        valid_sig_algs = self.valid_sig_algs
        valid_groups = self.valid_groups

        server_key_exchange = ServerKeyExchange(self.cipher_suite,
                                                self.version)
        server_key_exchange.parse(parser)

        client_random = state.client_random
        server_random = state.server_random
        public_key = state.get_server_public_key()
        server_hello = state.get_last_message_of_type(ServerHello)
        if server_hello is None:
            # use a placeholder instance; assigning the attribute on the
            # class itself would change it for every ServerHello object
            server_hello = ServerHello()
            server_hello.server_version = state.version
        if valid_sig_algs is None:
            # if the value was unset in script, get the advertised value from
            # Client Hello
            client_hello = state.get_last_message_of_type(ClientHello)
            if client_hello is not None:
                sig_algs_ext = client_hello.getExtension(ExtensionType.
                                                         signature_algorithms)
                if sig_algs_ext is not None:
                    valid_sig_algs = sig_algs_ext.sigalgs
            if valid_sig_algs is None:
                # no advertised means support for sha1 only
                valid_sig_algs = [(HashAlgorithm.sha1,
                                   SignatureAlgorithm.rsa)]
                if self.cipher_suite in CipherSuite.ecdheEcdsaSuites:
                    valid_sig_algs = [(HashAlgorithm.sha1,
                                       SignatureAlgorithm.ecdsa)]

        try:
            KeyExchange.verifyServerKeyExchange(server_key_exchange,
                                                public_key,
                                                client_random,
                                                server_random,
                                                valid_sig_algs)
        except TLSDecryptionFailed:
            # very rarely validation of signature fails, print it so that
            # we have a chance in debugging it
            print("Bad signature: {0}"
                  .format(b2a_hex(server_key_exchange.signature)),
                  file=sys.stderr)
            raise

        if self.cipher_suite in CipherSuite.dhAllSuites:
            self._checkParams(server_key_exchange)
            state.key_exchange = DHE_RSAKeyExchange(self.cipher_suite,
                                                    clientHello=None,
                                                    serverHello=server_hello,
                                                    privateKey=None)
            state.key['ServerKeyExchange.key_share'] = \
                server_key_exchange.dh_Ys
            state.key['ServerKeyExchange.dh_p'] = server_key_exchange.dh_p
        elif self.cipher_suite in CipherSuite.ecdhAllSuites:
            # extract valid groups from Client Hello
            if valid_groups is None:
                client_hello = state.get_last_message_of_type(ClientHello)
                if client_hello is not None:
                    groups_ext = client_hello.getExtension(ExtensionType.
                                                           supported_groups)
                    if groups_ext is not None:
                        valid_groups = groups_ext.groups
                if valid_groups is None:
                    # no advertised means support for all
                    valid_groups = GroupName.allEC
            state.key_exchange = \
                ECDHE_RSAKeyExchange(self.cipher_suite,
                                     clientHello=None,
                                     serverHello=server_hello,
                                     privateKey=None,
                                     acceptedCurves=valid_groups)
            state.key['ServerKeyExchange.key_share'] = \
                server_key_exchange.ecdh_Ys
        else:
            raise AssertionError("Unsupported cipher selected")
        state.key['premaster_secret'] = state.key_exchange.\
            processServerKeyExchange(public_key,
                                     server_key_exchange)

        state.handshake_messages.append(server_key_exchange)
        state.handshake_hashes.update(msg.write())
# RFC 8446 Section 4.2 says that an implementation MUST reject extensions
# it recognises but which are not allowed in the message they were sent in.
# This is the set of extensions defined in RFC 8446 that are not allowed in
# a TLS 1.3 CertificateRequest; numeric codepoints have their extension
# names in the comments.
TLS_1_3_CR_FORBIDDEN = {
    ExtensionType.server_name,
    1,   # max_fragment_length
    ExtensionType.supported_groups,
    14,  # use_srtp
    ExtensionType.heartbeat,
    ExtensionType.alpn,
    19,  # client_certificate_type
    20,  # server_certificate_type
    21,  # padding
    ExtensionType.key_share,
    ExtensionType.pre_shared_key,
    ExtensionType.psk_key_exchange_modes,
    ExtensionType.early_data,
    ExtensionType.cookie,
    ExtensionType.supported_versions,
    49,  # post_handshake_auth
}
[docs]
class ExpectCertificateRequest(_ExpectExtensionsMessage):
    """Processing TLS Handshake protocol Certificate Request message."""

    def __init__(self, sig_algs=None, cert_types=None,
                 sanity_check_cert_types=True, extensions=None, context=None):
        """
        Set expected parameters for the CertificateRequest message.

        :param sig_algs: a list of signature algorithms that we are expecting
            from server. Needs to be in-order and complete. ``None`` to accept
            any list from server. Applicable to TLS 1.2 and later only.
            Do not use together with non-default ``extensions``.
        :param cert_types: a list of client certificate types that we are
            expecting from server. Needs to be in-order and complete.
            ``None`` to accept any list from server. Applicable to TLS 1.2 and
            earlier only.
        :param sanity_check_cert_types: set to ``False`` to disable
            verification checking if every signature algorithm has a
            corresponding client certificate type.
        :param extensions: dictionary with extensions that need to be included
            in the message. Set to ``None`` to accept any, set to empty dict to
            expect no extensions. Usable in TLS 1.3 only.
        :param context: list to which the parsed CertificateRequest will be
            appended, ``None`` to not record it
        :raises ValueError: when both ``sig_algs`` and ``extensions`` are set
        """
        super(ExpectCertificateRequest, self).__init__(
            ContentType.handshake,
            HandshakeType.certificate_request,
            extensions)
        self.sig_algs = sig_algs
        self.cert_types = cert_types
        self.context = context
        self.sanity_check_cert_types = sanity_check_cert_types
        if sig_algs is not None and extensions is not None:
            raise ValueError("Can't set sig_algs and extensions at the same "
                             "time")

    @staticmethod
    def _sanity_check_cert_types(cert_request):
        """
        Verify that the CertificateRequest is self-consistent.

        Every advertised signature algorithm needs the matching client
        certificate type to be advertised too.

        :raises AssertionError: when a certificate type is missing or a
            signature algorithm is not supported by the check
        """
        # signature algorithm id -> (key type name, client certificate type)
        known_algs = {
            SignatureAlgorithm.ecdsa: ("ECDSA", "ecdsa_sign"),
            SignatureAlgorithm.ed25519: ("ECDSA", "ecdsa_sign"),
            SignatureAlgorithm.ed448: ("ECDSA", "ecdsa_sign"),
            SignatureAlgorithm.rsa: ("RSA", "rsa_sign"),
            SignatureAlgorithm.dsa: ("DSA", "dss_sign"),
        }
        for sig_alg in cert_request.supported_signature_algs:
            if sig_alg[1] in known_algs:
                key_type, cert_type = known_algs[sig_alg[1]]
            else:
                # remaining schemes (like RSA-PSS) are resolved by name
                sig_scheme = SignatureScheme.toRepr(sig_alg)
                key_type = SignatureScheme.getKeyType(sig_scheme)
                assert key_type == "rsa", \
                    "Unsupported signature algorithm: {0}".format(sig_alg)
                cert_type = "rsa_sign"

            required = getattr(ClientCertificateType, cert_type)
            if required not in cert_request.certificate_types:
                raise AssertionError(
                    "CertificateRequest includes {1} signature algorithms "
                    "({0}) but does not include {2} client "
                    "certificate type".format(sig_alg, key_type, cert_type))

    @staticmethod
    def _get_autohandler(ext_id):
        """Return default handler for the extension, ``None`` if unknown."""
        try:
            handler = _CR_EXT_HANDLER[ext_id]
        except KeyError:
            # handle future/GREASE extensions
            handler = None
        return handler

    def _process_extensions(self, state, msg):
        """Reject forbidden extensions and run handlers for the rest."""
        for ext in msg.extensions:
            ext_id = ext.extType
            if ext_id in TLS_1_3_CR_FORBIDDEN:
                raise AssertionError(
                    "Server sent extension that is explicitly forbidden in "
                    "CertificateRequest messages: {0}".format(
                        ExtensionType.toStr(ext_id)))

            handler = self.extensions[ext_id] if self.extensions else None
            if handler is None:
                handler = self._get_autohandler(ext_id)

            if handler is None:
                # since server can send arbitrary extensions, we need to
                # be able to process them, so if there is no handler
                # we just skip the extension
                continue
            if callable(handler):
                handler(state, ext)
            elif isinstance(handler, TLSExtension):
                if not handler == ext:
                    raise AssertionError(
                        "Expected extension not matched for type {0}, "
                        "received: {1}".format(ExtensionType.toStr(ext_id),
                                               ext))
            else:
                raise ValueError("Bad extension handler for id {0}".format(
                    ExtensionType.toStr(ext_id)))

    def process(self, state, msg):
        """
        Check received Certificate Request

        :type state: ConnectionState
        """
        assert msg.contentType == ContentType.handshake
        raw = msg.write()
        parser = Parser(raw)
        hs_type = parser.get(1)
        assert hs_type == HandshakeType.certificate_request

        cert_request = CertificateRequest(state.version)
        cert_request.parse(parser)

        self._cmp_eq_list(self.sig_algs, cert_request.supported_signature_algs,
                          SignatureScheme,
                          f_str="Unexpected signature algorithms. Got: {1}, "
                                "expected: {0}")
        self._cmp_eq_list(self.cert_types, cert_request.certificate_types,
                          ClientCertificateType,
                          f_str="Unexpected client certificate types. Got: "
                                "{1}, expected: {0}")

        # only in TLS 1.2 do the sig algs coexist with cert types
        if state.version == (3, 3) and self.sanity_check_cert_types:
            self._sanity_check_cert_types(cert_request)

        if state.version >= (3, 4):
            self._compare_extensions(cert_request)
            self._process_extensions(state, cert_request)

        if self.context is not None:
            self.context.append(cert_request)

        state.handshake_messages.append(cert_request)
        state.handshake_hashes.update(raw)
[docs]
class ExpectServerHelloDone(ExpectHandshake):
    """Processing TLS Handshake protocol ServerHelloDone messages"""

    def __init__(self):
        super(ExpectServerHelloDone, self).__init__(
            ContentType.handshake, HandshakeType.server_hello_done)

    def process(self, state, msg):
        """
        Parse the ServerHelloDone message and add it to the handshake.

        :type state: ConnectionState
        :type msg: Message
        """
        assert msg.contentType == ContentType.handshake
        raw = msg.write()
        parser = Parser(raw)
        hs_type = parser.get(1)
        assert hs_type == HandshakeType.server_hello_done

        done = ServerHelloDone()
        done.parse(parser)

        state.handshake_messages.append(done)
        state.handshake_hashes.update(raw)
[docs]
class ExpectChangeCipherSpec(Expect):
    """
    Processing TLS Change Cipher Spec messages.

    .. note::
        In SSLv3 up to TLS 1.2, the message modifies the state of record layer
        to expect encrypted records *after* receiving this message.
        In case of renegotiation, record layer will expect records encrypted
        with the newly negotiated keys. In TLS 1.3 it has no effect on record
        layer encryption.
    """

    def __init__(self):
        super(ExpectChangeCipherSpec, self).__init__(
            ContentType.change_cipher_spec)

    def process(self, state, msg):
        """
        Check the ChangeCipherSpec and, before TLS 1.3, switch read keys.

        :type state: ConnectionState
        :type msg: Message
        """
        assert msg.contentType == ContentType.change_cipher_spec
        ccs = ChangeCipherSpec().parse(Parser(msg.write()))
        assert ccs.type == 1

        if state.version >= (3, 4):
            # in TLS 1.3 the CCS does not have any affect on encryption
            return

        if state.resuming:
            # in resumption the pending states were not calculated yet
            state.msg_sock.encryptThenMAC = state.encrypt_then_mac
            calc_pending_states(state)

        state.msg_sock.changeReadState()
        limit = state._our_record_size_limit
        if limit:
            state.msg_sock.recv_record_limit = limit
[docs]
class ExpectVerify(ExpectHandshake):
    """Processing of SSLv2 SERVER-VERIFY message"""

    def __init__(self):
        super(ExpectVerify, self).__init__(ContentType.handshake,
                                           SSL2HandshakeType.server_verify)

    def process(self, state, msg):
        """
        Check that the message is an SSLv2 SERVER-VERIFY message.

        Only the message type is checked: the rest of the message body is
        not inspected and the connection state is not modified.
        """
        assert msg.contentType == ContentType.handshake
        msg_type = Parser(msg.write()).get(1)
        assert msg_type == SSL2HandshakeType.server_verify
[docs]
class ExpectFinished(ExpectHandshake):
    """
    Processing TLS handshake protocol Finished message.

    .. note::
        In TLS 1.3 the message will modify record layer to start *sending*
        records with encryption using the ``client_handshake_traffic_secret``
        keys.
        It will also modify the record layer to start expecting the records
        to be encrypted with ``server_application_traffic_secret`` keys.
    """

    def __init__(self, version=None, description=None):
        """
        Initialize object.

        .. note::
            The ``description`` parameter MUST be specified
            as a keyword argument, i.e. read the definition as
            ``(self, *, description=None)`` (see PEP 3102).
            Otherwise the behaviour of this node is not guaranteed if new
            arguments are added to it (as they will be added *before*
            the ``description`` argument).

        :param tuple(int,int) version: protocol version of the expected
            message, ``(0, 2)`` or ``(2, 0)`` selects the SSLv2
            SERVER-FINISHED message. As the message type is selected here,
            SSLv2 must be specified explicitly; for other protocol versions
            it is taken from the connection state when not provided.
        :param str description: name or comment attached to the node,
            it will be printed when :py:func:`str` or :py:func:`repr` is
            called on the node.
        """
        if version in ((0, 2), (2, 0)):
            handshake_type = SSL2HandshakeType.server_finished
        else:
            handshake_type = HandshakeType.finished
        super(ExpectFinished, self).__init__(ContentType.handshake,
                                             handshake_type)
        self.version = version
        self.description = description

    @staticmethod
    def _setup_tls13_application_keys(state):
        """Derive TLS 1.3 application secrets and switch record layer keys."""
        # in TLS 1.3 ChangeCipherSpec is a no-op, so we need to attach
        # the change for writing to some message that is always sent
        state.msg_sock.changeWriteState()

        # we now need to calculate application traffic keys to allow
        # correct interpretation of the alerts regarding Certificate,
        # CertificateVerify and Finished

        # derive the master secret
        master = derive_secret(
            state.key['handshake secret'], b'derived', None,
            state.prf_name)
        master = secureHMAC(
            master, bytearray(state.prf_size), state.prf_name)
        state.key['master secret'] = master

        # derive encryption keys
        client_secret = derive_secret(
            master, b'c ap traffic', state.handshake_hashes,
            state.prf_name)
        state.key['client application traffic secret'] = client_secret
        server_secret = derive_secret(
            master, b's ap traffic', state.handshake_hashes,
            state.prf_name)
        state.key['server application traffic secret'] = server_secret

        # derive TLS exporter key
        exporter = derive_secret(master, b'exp master',
                                 state.handshake_hashes,
                                 state.prf_name)
        state.key['exporter master secret'] = exporter

        # set up the encryption keys for application data
        state.msg_sock.calcTLS1_3PendingState(
            state.cipher, client_secret, server_secret, None)
        state.msg_sock.changeReadState()

    def process(self, state, msg):
        """
        Verify the Finished message and update the connection state.

        :type state: ConnectionState
        :type msg: Message
        :raises AssertionError: when the verify_data does not match
        """
        assert msg.contentType == ContentType.handshake
        parser = Parser(msg.write())
        hs_type = parser.get(1)
        assert hs_type == self.handshake_type

        if self.version is None:
            self.version = state.version
        is_ssl2 = self.version in ((0, 2), (2, 0))

        if is_ssl2:
            finished = ServerFinished()
        else:
            finished = Finished(self.version, state.prf_size)
        finished.parse(parser)

        if is_ssl2:
            state.session_id = finished.verify_data
        elif self.version <= (3, 3):
            # we verify the peer's Finished, so the label is the one of
            # the other side
            label = b'server finished' if state.client \
                else b'client finished'
            verify_expected = calc_key(state.version,
                                       state.key['master_secret'],
                                       state.cipher,
                                       label,
                                       state.handshake_hashes,
                                       output_length=12)
            assert finished.verify_data == verify_expected
        else:  # TLS 1.3
            finished_key = HKDF_expand_label(
                state.key['server handshake traffic secret'],
                b'finished',
                b'',
                state.prf_size,
                state.prf_name)
            transcript_hash = state.handshake_hashes.digest(state.prf_name)
            verify_expected = secureHMAC(finished_key,
                                         transcript_hash,
                                         state.prf_name)
            assert finished.verify_data == verify_expected

        state.handshake_messages.append(finished)
        state.key['server_verify_data'] = finished.verify_data
        state.handshake_hashes.update(msg.write())

        if is_ssl2:
            state.msg_sock.handshake_finished = True

        if self.version > (3, 3):
            self._setup_tls13_application_keys(state)

    def __repr__(self):
        """Return human readable representation of the object."""
        return self._repr(['description'])
[docs]
class ExpectEncryptedExtensions(_ExpectExtensionsMessage):
    """Processing of the TLS handshake protocol Encrypted Extensions message"""

    def __init__(self, extensions=None):
        """
        Set expectations for the EncryptedExtensions message.

        :param extensions: dictionary of expected extensions (values are
            callables, TLSExtension objects to compare with, or ``None`` to
            use the automatic handler), ``None`` to accept any extensions the
            client advertised
        """
        super(ExpectEncryptedExtensions, self).__init__(
            ContentType.handshake,
            HandshakeType.encrypted_extensions,
            extensions)

    def _compare_extensions_in_ee(self, srv_exts, cln_hello):
        """
        Verify that server provided extensions match exactly expected list.

        When no expected extensions were set, every received extension must
        have been advertised in the ClientHello.
        """
        # check if received extensions match the set extensions
        self._compare_extensions(srv_exts)

        if self.extensions is not None or not srv_exts.extensions:
            return

        advertised = set(i.extType for i in cln_hello.extensions)
        received = set(i.extType for i in srv_exts.extensions)
        unsolicited = received.difference(advertised)
        if unsolicited:
            raise AssertionError(
                "Server sent unexpected extension(s): {0}".format(
                    ", ".join(ExtensionType.toStr(i) for i in unsolicited)))

    @staticmethod
    def _get_autohandler(ext_id):
        """Return default handler for the extension, raise if there's none."""
        try:
            return _EE_EXT_HANDLER[ext_id]
        except KeyError:
            raise ValueError("No autohandler for {0}"
                             .format(ExtensionType.toStr(ext_id)))

    def _process_extensions(self, state, srv_exts):
        """Check if extensions are correct."""
        # fix these constants, when the extensions are implemented
        ee_supported = (ExtensionType.server_name,
                        1,   # max_fragment_length - RFC 6066
                        ExtensionType.supported_groups,
                        14,  # use_srtp - RFC 5764
                        ExtensionType.heartbeat,  # RFC 6520
                        ExtensionType.alpn,
                        19,  # client_certificate_type
                             # draft-ietf-tls-tls13-28 / RFC 7250
                        20,  # server_certificate_type
                             # draft-ietf-tls-tls13-28 / RFC 7250
                        ExtensionType.record_size_limit,  # RFC 8449
                        ExtensionType.early_data)

        for ext in srv_exts.extensions:
            ext_id = ext.extType
            if ext_id not in ee_supported:
                raise AssertionError("Server sent unsupported "
                                     "extension of type {0}"
                                     .format(ExtensionType.toStr(ext_id)))

            handler = self.extensions[ext_id] if self.extensions else None
            # use automatic handlers for some extensions
            if handler is None:
                handler = self._get_autohandler(ext_id)

            if callable(handler):
                handler(state, ext)
            elif isinstance(handler, TLSExtension):
                if not handler == ext:
                    raise AssertionError("Expected extension not "
                                         "matched for type {0}, "
                                         "received: {1}"
                                         .format(ExtensionType.toStr(ext_id),
                                                 ext))
            else:
                raise ValueError("Bad extension handler for id {0}"
                                 .format(ExtensionType.toStr(ext_id)))

    def process(self, state, msg):
        """Parse and verify the EncryptedExtensions message."""
        assert msg.contentType == ContentType.handshake
        raw = msg.write()
        parser = Parser(raw)
        hs_type = parser.get(1)
        assert hs_type == self.handshake_type

        srv_exts = EncryptedExtensions().parse(parser)

        # get client_hello message with CH extensions
        self._compare_extensions_in_ee(
            srv_exts, state.get_last_message_of_type(ClientHello))

        if srv_exts.extensions:
            self._process_extensions(state, srv_exts)

        state.handshake_messages.append(srv_exts)
        state.handshake_hashes.update(raw)
[docs]
class ExpectNewSessionTicket(ExpectHandshake):
    """Processing TLS handshake protocol new session ticket message."""

    def __init__(self, version=None, description=None):
        """
        Initialise object.

        .. note::
            The ``description`` parameter MUST be specified
            as a keyword argument, i.e. read the definition as
            ``(self, *, description=None)`` (see PEP 3102).
            Otherwise the behaviour of this node is not guaranteed if new
            arguments are added to it (as they will be added *before*
            the ``description`` argument).

        :param tuple version: parse the message as in the specified TLS
            version, use negotiated version by default
        :param str description: name or comment attached to the node,
            it will be printed when :py:func:`str` or :py:func:`repr` is
            called on the node.
        """
        super(ExpectNewSessionTicket, self).__init__(
            ContentType.handshake,
            HandshakeType.new_session_ticket)
        self.description = description
        self.version = version

    def process(self, state, msg):
        """Parse, verify and process the message."""
        assert msg.contentType == ContentType.handshake
        msg_bytes = msg.write()
        parser = Parser(msg_bytes)
        hs_type = parser.get(1)
        assert hs_type == HandshakeType.new_session_ticket

        if self.version is None:
            self.version = state.version
        pre_tls13 = self.version < (3, 4)

        ticket_class = NewSessionTicket1_0 if pre_tls13 else NewSessionTicket
        ticket = ticket_class().parse(parser)
        # remember when the ticket was received
        ticket.time = time.time()
        state.session_tickets.append(ticket)

        if pre_tls13:
            # in TLS 1.2 and earlier tickets are part of the Handshake, so
            # they need to be hashed
            state.handshake_messages.append(ticket)
            state.handshake_hashes.update(msg_bytes)

    def __repr__(self):
        """Return human readable representation of object."""
        return self._repr(['description'])
[docs]
class ExpectHelloRequest(ExpectHandshake):
    """Processing of TLS handshake protocol hello request message."""

    def __init__(self, description=None):
        """
        Initialise object.

        .. note::
            The ``description`` parameter MUST be specified
            as a keyword argument, i.e. read the definition as
            ``(self, *, description=None)`` (see PEP 3102).
            Otherwise the behaviour of this node is not guaranteed if new
            arguments are added to it (as they will be added *before*
            the ``description`` argument).

        :param str description: name or comment attached to the node,
            it will be printed when :py:func:`str` or :py:func:`repr` is
            called on the node.
        """
        super(ExpectHelloRequest, self).__init__(
            ContentType.handshake,
            HandshakeType.hello_request)
        self.description = description

    def process(self, state, msg):
        """
        Check that the message is a well-formed HelloRequest.

        The message is only parsed; it is not recorded in the connection
        state.
        """
        assert msg.contentType == ContentType.handshake
        parser = Parser(msg.write())
        msg_type = parser.get(1)
        assert msg_type == HandshakeType.hello_request

        # parse only to check the encoding, the result is not needed
        HelloRequest().parse(parser)

    def __repr__(self):
        """Return human readable representation of object."""
        return self._repr(['description'])
[docs]
class ExpectAlert(Expect):
    """Processing TLS Alert message"""

    def __init__(self, level=None, description=None):
        """
        Set the expected alert values.

        :param int level: expected alert level, ``None`` to accept any
        :param description: expected alert description, or an iterable of
            acceptable descriptions, ``None`` to accept any
        """
        super(ExpectAlert, self).__init__(ContentType.alert)
        self.level = level
        self.description = description

    def process(self, state, msg):
        """
        Parse the alert and compare it with the expected values.

        :raises AssertionError: when the level or description do not match
        """
        assert msg.contentType == ContentType.alert
        parser = Parser(msg.write())

        alert = Alert()
        alert.parse(parser)

        problem_desc = ""
        if self.level is not None and alert.level != self.level:
            problem_desc += "Alert level {0} != {1}".format(alert.level,
                                                             self.level)
        if self.description is not None:
            # allow for multiple choice for description; work on a local
            # copy so that the node itself is not modified (it may be
            # processed again and its repr() should show the configured value)
            if isinstance(self.description, Iterable):
                allowed = tuple(self.description)
            else:
                allowed = (self.description,)
            if alert.description not in allowed:
                if problem_desc:
                    problem_desc += ", "
                descriptions = ["\"{0}\"".format(AlertDescription.toStr(i))
                                for i in allowed]
                # "a", "b" or "c"
                expected = ", ".join(descriptions[:-2] +
                                     [" or ".join(descriptions[-2:])])
                received = AlertDescription.toStr(alert.description)
                problem_desc += ("Expected alert description {0} does not "
                                 "match received \"{1}\""
                                 .format(expected, received))
        if problem_desc:
            raise AssertionError(problem_desc)

    def __repr__(self):
        """Return human readable representation of object."""
        return self._repr(["level", "description"])
[docs]
class ExpectSSL2Alert(ExpectHandshake):
    """Processing of SSLv2 Handshake protocol alert messages"""

    def __init__(self, error=None):
        """
        Set expectations for the SSLv2 ERROR message.

        :param int error: expected error code, ``None`` to accept any
        """
        super(ExpectSSL2Alert, self).__init__(ContentType.handshake,
                                              SSL2HandshakeType.error)
        self.error = error

    def process(self, state, msg):
        """Check the SSLv2 error message and, if set, its error code."""
        assert msg.contentType == ContentType.handshake
        parser = Parser(msg.write())
        msg_type = parser.get(1)
        assert msg_type == SSL2HandshakeType.error

        if self.error is None:
            return
        assert self.error == parser.get(2)
[docs]
class ExpectApplicationData(Expect):
    """Processing Application Data message"""

    def __init__(self, data=None, size=None, output=None, description=None):
        """
        Set the expectations for the application data record.

        :param data: exact payload that is expected, ``None`` to accept any
        :param int size: exact payload length that is expected, ``None`` to
            accept any
        :param output: file-like object to which the repr() of the received
            payload will be written, ``None`` to not print it
        :param str description: name or comment attached to the node
        """
        super(ExpectApplicationData, self).\
            __init__(ContentType.application_data)
        self.data = data
        self.size = size
        self.output = output
        self.description = description

    def __str__(self):
        """Return human readable representation of the object."""
        return self._repr(['data', 'size', 'description'])

    def process(self, state, msg):
        """
        Check the received payload against the expected value and size.

        :raises AssertionError: when the payload or its length don't match
        """
        assert msg.contentType == ContentType.application_data
        data = msg.write()

        # compare against None explicitly so that an expected empty
        # payload (e.g. ``data=b""`` or ``size=0``) is actually verified
        if self.data is not None:
            assert self.data == data
        if self.size is not None and len(data) != self.size:
            raise AssertionError("ApplicationData of unexpected size: {0}, "
                                 "expected: {1}".format(len(data), self.size))
        if self.output:
            self.output.write("ExpectApplicationData received payload:\n")
            self.output.write(repr(data))
            self.output.write("ExpectApplicationData end of payload.\n")
[docs]
class ExpectHeartbeat(ExpectMessage):
    """Processing of heartbeat messages."""

    def __init__(self, message_type=HeartbeatMessageType.heartbeat_response,
                 payload=None, padding_size=None):
        """
        Set up waiting for a heartbeat message.

        :type message_type: int
        :param message_type: Type of heartbeat messages to wait for, see
            `~tlslite.constants.HeartbeatMessageType` for defined types
        :type payload: bytes-like
        :param payload: literal value of the payload to expect, if set to
            ``None``, any payload will be accepted
        :type padding_size: int
        :param padding_size: exact length of padding that will be expected,
            if set to ``None``, any padding of at least 16 bytes will be
            accepted
        """
        super(ExpectHeartbeat, self).__init__(ContentType.heartbeat)
        self.message_type = message_type
        self.payload = payload
        self.padding_size = padding_size

    def process(self, state, msg):
        """Check if the ``msg`` meets the requirements for the message."""
        assert msg.contentType == ContentType.heartbeat
        heartbeat = Heartbeat().parse(Parser(msg.write()))

        self._cmp_eq(self.message_type, heartbeat.message_type,
                     HeartbeatMessageType,
                     "Unexpected heartbeat message type. Expected: {0}, "
                     "received: {1}.")
        self._cmp_eq(self.payload, heartbeat.payload,
                     f_str="Unexpected payload in Heartbeat message "
                           "received. Expected: {0!r}, received: {1!r}")

        padding_len = len(heartbeat.padding)
        if self.padding_size is not None:
            if padding_len != self.padding_size:
                raise AssertionError(
                    "Server sent unexpected size of padding "
                    "in heartbeat message. Expected: {0}, "
                    "received: {1}".format(self.padding_size, padding_len))
        else:
            # minimal padding length
            assert padding_len >= 16
[docs]
class ExpectNoMessage(Expect):
    """
    Virtual message signifying timeout on message listen.

    :ivar timeout: how long to wait for message before giving up, in seconds,
        can be float
    :vartype timeout: int or float
    """

    def __init__(self, timeout=0.1):
        super(ExpectNoMessage, self).__init__(None)
        self.timeout = timeout

    def process(self, state, msg):
        """Do nothing, there is no message content to check."""
[docs]
class ExpectClose(Expect):
    """Virtual message signifying closing of TCP connection"""

    def __init__(self):
        super(ExpectClose, self).__init__(None)

    def process(self, state, msg):
        """Close our side of the TCP connection."""
        sock = state.msg_sock.sock
        sock.close()
[docs]
class ExpectCertificateStatus(ExpectHandshake):
    """Processing of CertificateStatus message from RFC 6066."""

    def __init__(self):
        super(ExpectCertificateStatus, self).__init__(
            ContentType.handshake, HandshakeType.certificate_status)

    def process(self, state, msg):
        """Parse the CertificateStatus message and add it to the handshake."""
        assert msg.contentType == ContentType.handshake
        raw = msg.write()
        parser = Parser(raw)
        hs_type = parser.get(1)
        assert hs_type == HandshakeType.certificate_status

        status = CertificateStatus().parse(parser)

        state.handshake_messages.append(status)
        state.handshake_hashes.update(raw)
[docs]
class ExpectKeyUpdate(ExpectHandshake):
    """Processing of post-handshake KeyUpdate message from RFC 8446"""

    def __init__(self, message_type=None):
        """
        Initialize object.

        :type message_type: int
        :param message_type: type of KeyUpdate msg, either
            update_not_requested or update_requested; ``None`` to accept
            either of them
        """
        super(ExpectKeyUpdate, self).__init__(
            ContentType.handshake,
            HandshakeType.key_update)
        self.message_type = message_type

    def process(self, state, msg):
        """
        Parse, verify and process the message.

        Calls the record layer's ``calcTLS1_3KeyUpdate_sender`` and stores
        the returned server application traffic secret in ``state.key``.

        :type state: ConnectionState
        :type msg: Message
        :raises AssertionError: when the KeyUpdate type is not the expected
            one
        """
        assert msg.contentType == self.content_type
        parser = Parser(msg.write())
        hs_type = parser.get(1)
        assert hs_type == self.handshake_type

        keyupdate = KeyUpdate().parse(parser)

        # with the default of None any KeyUpdate type is accepted; before,
        # the comparison with None made every message fail
        if self.message_type is not None and \
                keyupdate.message_type != self.message_type:
            raise AssertionError(
                "Unexpected KeyUpdate message type. Expected: {0}, "
                "received: {1}".format(self.message_type,
                                       keyupdate.message_type))

        _, sr_app_secret = state.msg_sock.\
            calcTLS1_3KeyUpdate_sender(
                state.cipher,
                state.key['client application traffic secret'],
                state.key['server application traffic secret'])
        state.key['server application traffic secret'] = sr_app_secret