Module sshkey_tools.utils

Utilities for handling keys and certificates

Expand source code
"""
Utilities for handling keys and certificates
"""
import hashlib as hl
import sys
import datetime

from base64 import b64encode
from random import randint
from secrets import randbits
from typing import Dict, List, Union
from uuid import uuid4

from pytimeparse2 import parse as time_parse


# Type of the ``None`` singleton, used in the ``Union[...]`` annotations below
# (``types.NoneType`` is only available from Python 3.10 onwards).
NoneType = None.__class__


def ensure_string(
    obj: Union[str, bytes, list, tuple, set, dict, NoneType],
    encoding: str = "utf-8",
    required: bool = False,
) -> Union[str, List[str], Dict[str, str], NoneType]:
    """Ensure the provided value is or contains a string/strings

    Args:
        obj (_type_): The object to process
        encoding (str, optional): The encoding of the provided strings. Defaults to 'utf-8'.

    Returns:
        Union[str, List[str], Dict[str, str]]: Returns a string, list of strings or
                                               dictionary with strings
    """
    if (obj is None and not required) or isinstance(obj, str):
        return obj
    if isinstance(obj, bytes):
        return obj.decode(encoding)
    if isinstance(obj, (list, tuple, set)):
        return [ensure_string(o, encoding) for o in obj]
    if isinstance(obj, dict):
        return {
            ensure_string(k, encoding): ensure_string(v, encoding)
            for k, v in obj.items()
        }

    raise TypeError(
        f"Expected one of (str, bytes, list, tuple, dict, set), got {type(obj).__name__}."
    )


def ensure_bytestring(
    obj: Union[str, bytes, list, tuple, set, dict, NoneType],
    encoding: str = "utf-8",
    required: bool = None,
) -> Union[str, List[str], Dict[str, str], NoneType]:
    """Ensure the provided value is or contains a bytestring/bytestrings

    Args:
        obj (_type_): The object to process
        encoding (str, optional): The encoding of the provided bytestrings. Defaults to 'utf-8'.

    Returns:
        Union[str, List[str], Dict[str, str]]: Returns a bytestring, list of bytestrings or
                                               dictionary with bytestrings
    """
    if (obj is None and not required) or isinstance(obj, bytes):
        return obj
    if isinstance(obj, str):
        return obj.encode(encoding)
    if isinstance(obj, (list, tuple, set)):
        return [ensure_bytestring(o, encoding) for o in obj]
    if isinstance(obj, dict):
        return {
            ensure_bytestring(k, encoding): ensure_bytestring(v, encoding)
            for k, v in obj.items()
        }
    raise TypeError(
        f"Expected one of (str, bytes, list, tuple, dict, set), got {type(obj).__name__}."
    )


def concat_to_string(*strs, encoding: str = "utf-8") -> str:
    """Join any number of str/bytes values into a single string.

    The arguments are passed through ``ensure_string`` (bytes are decoded with
    ``encoding``); ``None`` arguments contribute nothing to the result.

    Args:
        *strs (Union[str, bytes, None]): The values to concatenate, in order.
        encoding (str, optional): Codec used to decode bytes. Defaults to 'utf-8'.

    Returns:
        str: The concatenated string.
    """
    decoded = ensure_string(strs, encoding)
    return "".join("" if piece is None else piece for piece in decoded)


def concat_to_bytestring(*strs, encoding: str = "utf-8") -> bytes:
    """Join any number of str/bytes values into a single bytestring.

    The arguments are passed through ``ensure_bytestring`` (str values are
    encoded with ``encoding``); ``None`` arguments contribute nothing.

    Args:
        *strs (Union[str, bytes, None]): The values to concatenate, in order.
        encoding (str, optional): Codec used to encode str values. Defaults to 'utf-8'.

    Returns:
        bytes: The concatenated bytestring.
    """
    encoded = ensure_bytestring(strs, encoding=encoding)
    return b"".join(b"" if piece is None else piece for piece in encoded)


def random_keyid() -> str:
    """Return a fresh random key ID.

    Returns:
        str: A version-4 UUID in its canonical hyphenated string form.
    """
    key_id = uuid4()
    return f"{key_id}"


def random_serial() -> int:
    """Generates a random serial number

    The value comes from the ``secrets`` CSPRNG rather than the ``random``
    module, so serials are not predictable from earlier outputs. The range is
    the same as before.

    Returns:
        int: Random serial in the range [0, 2**64 - 1] (an unsigned 64-bit value)
    """
    return randbits(64)


def long_to_bytes(
    source_int: int, force_length: int = None, byteorder: str = "big"
) -> bytes:
    """Converts a non-negative integer to a byte string conforming with the certificate format.
        Equivalent to paramiko.util.deflate_long()

    Without ``force_length`` the result is ``bit_length() // 8 + 1`` bytes long.
    A zero byte is therefore prepended whenever the top bit of the most
    significant byte would otherwise be set, and 0 encodes as a single zero byte.

    Args:
        source_int (int): Integer to convert; must be >= 0.
        force_length (int, optional): Exact output length in bytes; shorter
            values are zero-padded. Defaults to None (minimal length as above).
        byteorder (str, optional): Byte order. Defaults to 'big'.

    Returns:
        bytes: Byte string representing the chosen long integer

    Raises:
        TypeError: If ``source_int`` is not an integer.
        ValueError: If ``source_int`` is negative.
        OverflowError: If ``force_length`` is too small to hold the value.
    """
    # Check the type first: comparing a non-int with 0 would otherwise raise a
    # misleading ValueError (e.g. for -1.5) or an unrelated comparison TypeError.
    if not isinstance(source_int, int):
        raise TypeError(f"Expected integer, got {type(source_int).__name__}.")

    if source_int < 0:
        raise ValueError(
            "You can only convert positive long integers to bytes with this method"
        )

    # A falsy force_length (None or 0) falls back to the minimal length.
    length = force_length if force_length else source_int.bit_length() // 8 + 1
    return source_int.to_bytes(length, byteorder)


def bytes_to_long(source_bytes: bytes, byteorder: str = "big") -> int:
    """Decode a byte string into a non-negative integer (inverse of long_to_bytes).

    The bytes are read as an unsigned value, so a set high bit never yields a
    negative result. Intended as the counterpart of paramiko.util.inflate_long().

    Args:
        source_bytes (bytes): The byte string to decode; must be a ``bytes``
            instance (``bytearray``/``memoryview`` are rejected).
        byteorder (str, optional): 'big' or 'little'. Defaults to 'big'.

    Returns:
        int: The decoded integer (0 for an empty byte string).

    Raises:
        TypeError: If ``source_bytes`` is not ``bytes``.
    """
    if isinstance(source_bytes, bytes):
        return int.from_bytes(source_bytes, byteorder)
    raise TypeError(f"Expected bytes, got {type(source_bytes).__name__}.")


def generate_secure_nonce(length: int = 128) -> str:
    """Generates a secure random nonce of the specified length.
        Mainly important for ECDSA keys, but is used with all key/certificate types
        https://blog.trailofbits.com/2020/06/11/ecdsa-handle-with-care/
        https://datatracker.ietf.org/doc/html/rfc6979

    The nonce is ``length`` random bits from the ``secrets`` CSPRNG, returned as
    the decimal string of the resulting integer, so the string's own length
    varies.

    Args:
        length (int, optional): Number of random bits in the nonce. Defaults to 128.

    Returns:
        str: Decimal representation of an integer in the range [0, 2**length)
    """
    return str(randbits(length))


def md5_fingerprint(data: bytes, prefix: bool = True) -> str:
    """
    Compute the legacy OpenSSH-style MD5 fingerprint of ``data``.

    Args:
        data (bytes): The bytes to hash.
        prefix (bool, optional): Prepend ``MD5:`` when True. Defaults to True.

    Returns:
        str: Colon-separated lowercase hex octets, e.g. ``MD5:d4:1d:...:7e``.
    """
    hex_pairs = [f"{octet:02x}" for octet in hl.md5(data).digest()]
    label = "MD5:" if prefix else ""
    return label + ":".join(hex_pairs)


def sha256_fingerprint(data: bytes, prefix: bool = True) -> str:
    """
    Returns a SHA256 fingerprint of the given data.

    Args:
        data (bytes): The data to fingerprint
        prefix (bool, optional): Whether to prefix the fingerprint with SHA256:

    Returns:
        str: The fingerprint in OpenSSH style: ``SHA256:`` followed by the
        base64-encoded digest with its ``=`` padding removed.
    """
    digest = hl.sha256(data).digest()
    # base64 padding only ever appears at the end, so rstrip drops exactly it.
    encoded = b64encode(digest).decode("utf-8").rstrip("=")
    return ("SHA256:" if prefix else "") + encoded


def sha512_fingerprint(data: bytes, prefix: bool = True) -> str:
    """
    Returns a SHA512 fingerprint of the given data.

    Args:
        data (bytes): The data to fingerprint
        prefix (bool, optional): Whether to prefix the fingerprint with SHA512:

    Returns:
        str: The fingerprint in OpenSSH style: ``SHA512:`` followed by the
        base64-encoded digest with its ``=`` padding removed.
    """
    digest = hl.sha512(data).digest()
    # base64 padding only ever appears at the end, so rstrip drops exactly it.
    encoded = b64encode(digest).decode("utf-8").rstrip("=")
    return ("SHA512:" if prefix else "") + encoded


def nullsafe_getattr(obj, attr: str, default):
    """
    Like ``getattr(obj, attr, default)``, but also substitutes ``default`` for None.

    Args:
        obj: The object to read from.
        attr: Name of the attribute to fetch.
        default: Value returned when the attribute is missing or is None.

    Returns:
        The attribute value, or ``default`` if it is absent or None.
    """
    value = getattr(obj, attr, default)
    return default if value is None else value

def join_dicts(*dicts) -> dict:
    """
    Joins two or more dictionaries together.
    In case of duplicate keys, the latest one wins.

    The inputs are not modified; a new dictionary is always returned. Key order
    follows first appearance, as with ``{**a, **b}`` or ``a | b``.

    Args:
        *dicts: Mappings to merge, from lowest to highest precedence.

    Returns:
        dict: Joined dictionary (empty if no arguments are given)
    """
    # Updating one accumulator in place avoids re-copying the growing result for
    # every argument. It behaves the same on every Python version, so no
    # interpreter-version branch is needed.
    joined = {}
    for add_dict in dicts:
        joined.update(add_dict)
    return joined


def str_to_time_delta(str_delta: str) -> datetime.timedelta:
    """Parse a human-readable duration string into a timedelta.

    Parsing is delegated to the pytimeparse2 package (by wroberts/onegreyonewhite).
    Accepted formats include, for example:
        - 32m
        - 2h32m
        - 3d2h32m
        - 1w3d2h32m
        - 1w 3d 2h 32m
        - 1 w 3 d 2 h 32 m
        - 4:13
        - 4:13:02
        - 4:13:02.266
        - 2:04:13:02.266
        - 2 days, 4:13:02 (uptime format)
        - 2 days, 4:13:02.266
        - 5hr34m56s
        - 5 hours, 34 minutes, 56 seconds
        - 5 hrs, 34 mins, 56 secs
        - 2 days, 5 hours, 34 minutes, 56 seconds
        - 1.2 m
        - 1.2 min
        - 1.2 mins
        - 1.2 minute
        - 1.2 minutes
        - 172 hours
        - 172 hr
        - 172 h
        - 172 hrs
        - 172 hour
        - 1.24 days
        - 5 d
        - 5 day
        - 5 days
        - 5.6 wk
        - 5.6 week
        - 5.6 weeks

    Args:
        str_delta (str): The duration string to parse.

    Returns:
        datetime.timedelta: The parsed duration.

    Raises:
        ValueError: If parsing fails for any reason; the original exception is
            chained as ``__cause__``.
    """
    try:
        return time_parse(str_delta, as_timedelta=True, raise_exception=True)
    except Exception as parse_error:
        raise ValueError(
            f"Could not parse time delta string {str_delta} : {parse_error}"
        ) from parse_error

Functions

def bytes_to_long(source_bytes: bytes, byteorder: str = 'big') ‑> int

The opposite of long_to_bytes: converts a byte string to a non-negative long integer. Equivalent to paramiko.util.inflate_long() with always_positive=True.

Args

source_bytes : bytes
The byte string to convert
byteorder : str, optional
Byte order. Defaults to 'big'.

Returns

int
Long integer resulting from decoding the byte string
Expand source code
def bytes_to_long(source_bytes: bytes, byteorder: str = "big") -> int:
    """Decode a byte string into a non-negative integer (inverse of long_to_bytes).

    The bytes are read as an unsigned value, so a set high bit never yields a
    negative result. Intended as the counterpart of paramiko.util.inflate_long().

    Args:
        source_bytes (bytes): The byte string to decode; must be a ``bytes``
            instance (``bytearray``/``memoryview`` are rejected).
        byteorder (str, optional): 'big' or 'little'. Defaults to 'big'.

    Returns:
        int: The decoded integer (0 for an empty byte string).

    Raises:
        TypeError: If ``source_bytes`` is not ``bytes``.
    """
    if isinstance(source_bytes, bytes):
        return int.from_bytes(source_bytes, byteorder)
    raise TypeError(f"Expected bytes, got {type(source_bytes).__name__}.")
def concat_to_bytestring(*strs, encoding: str = 'utf-8') ‑> bytes

Concatenates a list of strings or bytestrings to a single bytestring.

Args

encoding : str, optional
The encoding of the string/s. Defaults to 'utf-8'.
*strs : List[str, bytes]
The strings to concatenate

Returns

bytes
Concatenated bytestring
Expand source code
def concat_to_bytestring(*strs, encoding: str = "utf-8") -> bytes:
    """Join any number of str/bytes values into a single bytestring.

    The arguments are passed through ``ensure_bytestring`` (str values are
    encoded with ``encoding``); ``None`` arguments contribute nothing.

    Args:
        *strs (Union[str, bytes, None]): The values to concatenate, in order.
        encoding (str, optional): Codec used to encode str values. Defaults to 'utf-8'.

    Returns:
        bytes: The concatenated bytestring.
    """
    encoded = ensure_bytestring(strs, encoding=encoding)
    return b"".join(b"" if piece is None else piece for piece in encoded)
def concat_to_string(*strs, encoding: str = 'utf-8') ‑> str

Concatenates a list of strings or bytestrings to a single string.

Args

encoding : str, optional
The encoding of the string/s. Defaults to 'utf-8'.
*strs : List[str, bytes]
The strings to concatenate

Returns

str
Concatenated string
Expand source code
def concat_to_string(*strs, encoding: str = "utf-8") -> str:
    """Join any number of str/bytes values into a single string.

    The arguments are passed through ``ensure_string`` (bytes are decoded with
    ``encoding``); ``None`` arguments contribute nothing to the result.

    Args:
        *strs (Union[str, bytes, None]): The values to concatenate, in order.
        encoding (str, optional): Codec used to decode bytes. Defaults to 'utf-8'.

    Returns:
        str: The concatenated string.
    """
    decoded = ensure_string(strs, encoding)
    return "".join("" if piece is None else piece for piece in decoded)
def ensure_bytestring(obj: Union[str, bytes, list, tuple, set, dict, None], encoding: str = 'utf-8', required: bool = None) ‑> Union[str, List[str], Dict[str, str], None]

Ensure the provided value is or contains a bytestring/bytestrings

Args

obj : _type_
The object to process
encoding : str, optional
The encoding of the provided bytestrings. Defaults to 'utf-8'.

Returns

Union[bytes, List[bytes], Dict[bytes, bytes]]
Returns a bytestring, list of bytestrings or dictionary with bytestrings
Expand source code
def ensure_bytestring(
    obj: Union[str, bytes, list, tuple, set, dict, NoneType],
    encoding: str = "utf-8",
    required: bool = None,
) -> Union[str, List[str], Dict[str, str], NoneType]:
    """Ensure the provided value is or contains a bytestring/bytestrings

    Args:
        obj (_type_): The object to process
        encoding (str, optional): The encoding of the provided bytestrings. Defaults to 'utf-8'.

    Returns:
        Union[str, List[str], Dict[str, str]]: Returns a bytestring, list of bytestrings or
                                               dictionary with bytestrings
    """
    if (obj is None and not required) or isinstance(obj, bytes):
        return obj
    if isinstance(obj, str):
        return obj.encode(encoding)
    if isinstance(obj, (list, tuple, set)):
        return [ensure_bytestring(o, encoding) for o in obj]
    if isinstance(obj, dict):
        return {
            ensure_bytestring(k, encoding): ensure_bytestring(v, encoding)
            for k, v in obj.items()
        }
    raise TypeError(
        f"Expected one of (str, bytes, list, tuple, dict, set), got {type(obj).__name__}."
    )
def ensure_string(obj: Union[str, bytes, list, tuple, set, dict, None], encoding: str = 'utf-8', required: bool = False) ‑> Union[str, List[str], Dict[str, str], None]

Ensure the provided value is or contains a string/strings

Args

obj : _type_
The object to process
encoding : str, optional
The encoding of the provided strings. Defaults to 'utf-8'.

Returns

Union[str, List[str], Dict[str, str]]
Returns a string, list of strings or dictionary with strings
Expand source code
def ensure_string(
    obj: Union[str, bytes, list, tuple, set, dict, NoneType],
    encoding: str = "utf-8",
    required: bool = False,
) -> Union[str, List[str], Dict[str, str], NoneType]:
    """Ensure the provided value is or contains a string/strings

    Args:
        obj (_type_): The object to process
        encoding (str, optional): The encoding of the provided strings. Defaults to 'utf-8'.

    Returns:
        Union[str, List[str], Dict[str, str]]: Returns a string, list of strings or
                                               dictionary with strings
    """
    if (obj is None and not required) or isinstance(obj, str):
        return obj
    if isinstance(obj, bytes):
        return obj.decode(encoding)
    if isinstance(obj, (list, tuple, set)):
        return [ensure_string(o, encoding) for o in obj]
    if isinstance(obj, dict):
        return {
            ensure_string(k, encoding): ensure_string(v, encoding)
            for k, v in obj.items()
        }

    raise TypeError(
        f"Expected one of (str, bytes, list, tuple, dict, set), got {type(obj).__name__}."
    )
def generate_secure_nonce(length: int = 128)

Generates a secure random nonce of the specified length. Mainly important for ECDSA keys, but is used with all key/certificate types https://blog.trailofbits.com/2020/06/11/ecdsa-handle-with-care/ https://datatracker.ietf.org/doc/html/rfc6979

Args

length : int, optional
Number of random bits in the nonce. Defaults to 128.

Returns

str
Nonce of the specified length
Expand source code
def generate_secure_nonce(length: int = 128) -> str:
    """Generates a secure random nonce of the specified length.
        Mainly important for ECDSA keys, but is used with all key/certificate types
        https://blog.trailofbits.com/2020/06/11/ecdsa-handle-with-care/
        https://datatracker.ietf.org/doc/html/rfc6979

    The nonce is ``length`` random bits from the ``secrets`` CSPRNG, returned as
    the decimal string of the resulting integer, so the string's own length
    varies.

    Args:
        length (int, optional): Number of random bits in the nonce. Defaults to 128.

    Returns:
        str: Decimal representation of an integer in the range [0, 2**length)
    """
    return str(randbits(length))
def join_dicts(*dicts) ‑> dict

Joins two or more dictionaries together. In case of duplicate keys, the latest one wins.

Returns

dict
Joined dictionary
Expand source code
def join_dicts(*dicts) -> dict:
    """
    Joins two or more dictionaries together.
    In case of duplicate keys, the latest one wins.

    The inputs are not modified; a new dictionary is always returned. Key order
    follows first appearance, as with ``{**a, **b}`` or ``a | b``.

    Args:
        *dicts: Mappings to merge, from lowest to highest precedence.

    Returns:
        dict: Joined dictionary (empty if no arguments are given)
    """
    # Updating one accumulator in place avoids re-copying the growing result for
    # every argument. It behaves the same on every Python version, so no
    # interpreter-version branch is needed.
    joined = {}
    for add_dict in dicts:
        joined.update(add_dict)
    return joined
def long_to_bytes(source_int: int, force_length: int = None, byteorder: str = 'big') ‑> bytes

Converts a non-negative integer to a byte string conforming with the certificate format. Equivalent to paramiko.util.deflate_long()

Args

source_int : int
Integer to convert
force_length : int, optional
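Exact length of the resulting bytestring; shorter values are zero-padded, and an OverflowError is raised if the value does not fit. Defaults to None (minimal length).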
byteorder : str, optional
Byte order. Defaults to 'big'.

Returns

bytes
Byte string representing the chosen long integer
Expand source code
def long_to_bytes(
    source_int: int, force_length: int = None, byteorder: str = "big"
) -> bytes:
    """Converts a non-negative integer to a byte string conforming with the certificate format.
        Equivalent to paramiko.util.deflate_long()

    Without ``force_length`` the result is ``bit_length() // 8 + 1`` bytes long.
    A zero byte is therefore prepended whenever the top bit of the most
    significant byte would otherwise be set, and 0 encodes as a single zero byte.

    Args:
        source_int (int): Integer to convert; must be >= 0.
        force_length (int, optional): Exact output length in bytes; shorter
            values are zero-padded. Defaults to None (minimal length as above).
        byteorder (str, optional): Byte order. Defaults to 'big'.

    Returns:
        bytes: Byte string representing the chosen long integer

    Raises:
        TypeError: If ``source_int`` is not an integer.
        ValueError: If ``source_int`` is negative.
        OverflowError: If ``force_length`` is too small to hold the value.
    """
    # Check the type first: comparing a non-int with 0 would otherwise raise a
    # misleading ValueError (e.g. for -1.5) or an unrelated comparison TypeError.
    if not isinstance(source_int, int):
        raise TypeError(f"Expected integer, got {type(source_int).__name__}.")

    if source_int < 0:
        raise ValueError(
            "You can only convert positive long integers to bytes with this method"
        )

    # A falsy force_length (None or 0) falls back to the minimal length.
    length = force_length if force_length else source_int.bit_length() // 8 + 1
    return source_int.to_bytes(length, byteorder)
def md5_fingerprint(data: bytes, prefix: bool = True) ‑> str

Returns an MD5 fingerprint of the given data.

Args

data : bytes
The data to fingerprint
prefix : bool, optional
Whether to prefix the fingerprint with MD5:

Returns

str
The fingerprint (OpenSSH style MD5:xx:xx:xx…)
Expand source code
def md5_fingerprint(data: bytes, prefix: bool = True) -> str:
    """
    Compute the legacy OpenSSH-style MD5 fingerprint of ``data``.

    Args:
        data (bytes): The bytes to hash.
        prefix (bool, optional): Prepend ``MD5:`` when True. Defaults to True.

    Returns:
        str: Colon-separated lowercase hex octets, e.g. ``MD5:d4:1d:...:7e``.
    """
    hex_pairs = [f"{octet:02x}" for octet in hl.md5(data).digest()]
    label = "MD5:" if prefix else ""
    return label + ":".join(hex_pairs)
def nullsafe_getattr(obj, attr: str, default)

Null-safe getattr, ensuring the result is not None. If the result is None, the default value is returned instead.

Args

obj
The object
attr
The attribute to get
default
The default value
Expand source code
def nullsafe_getattr(obj, attr: str, default):
    """
    Like ``getattr(obj, attr, default)``, but also substitutes ``default`` for None.

    Args:
        obj: The object to read from.
        attr: Name of the attribute to fetch.
        default: Value returned when the attribute is missing or is None.

    Returns:
        The attribute value, or ``default`` if it is absent or None.
    """
    value = getattr(obj, attr, default)
    return default if value is None else value
def random_keyid() ‑> str

Generates a random Key ID

Returns

str
Random keyid
Expand source code
def random_keyid() -> str:
    """Return a fresh random key ID.

    Returns:
        str: A version-4 UUID in its canonical hyphenated string form.
    """
    key_id = uuid4()
    return f"{key_id}"
def random_serial() ‑> str

Generates a random serial number

Returns

int
Random serial
Expand source code
def random_serial() -> int:
    """Generates a random serial number

    The value comes from the ``secrets`` CSPRNG rather than the ``random``
    module, so serials are not predictable from earlier outputs. The range is
    the same as before.

    Returns:
        int: Random serial in the range [0, 2**64 - 1] (an unsigned 64-bit value)
    """
    return randbits(64)
def sha256_fingerprint(data: bytes, prefix: bool = True) ‑> str

Returns a SHA256 fingerprint of the given data.

Args

data : bytes
The data to fingerprint
prefix : bool, optional
Whether to prefix the fingerprint with SHA256:

Returns

str
The fingerprint (OpenSSH style: SHA256: followed by the base64-encoded digest without = padding)
Expand source code
def sha256_fingerprint(data: bytes, prefix: bool = True) -> str:
    """
    Returns a SHA256 fingerprint of the given data.

    Args:
        data (bytes): The data to fingerprint
        prefix (bool, optional): Whether to prefix the fingerprint with SHA256:

    Returns:
        str: The fingerprint in OpenSSH style: ``SHA256:`` followed by the
        base64-encoded digest with its ``=`` padding removed.
    """
    digest = hl.sha256(data).digest()
    # base64 padding only ever appears at the end, so rstrip drops exactly it.
    encoded = b64encode(digest).decode("utf-8").rstrip("=")
    return ("SHA256:" if prefix else "") + encoded
def sha512_fingerprint(data: bytes, prefix: bool = True) ‑> str

Returns a SHA512 fingerprint of the given data.

Args

data : bytes
The data to fingerprint
prefix : bool, optional
Whether to prefix the fingerprint with SHA512:

Returns

str
The fingerprint (OpenSSH style: SHA512: followed by the base64-encoded digest without = padding)
Expand source code
def sha512_fingerprint(data: bytes, prefix: bool = True) -> str:
    """
    Returns a SHA512 fingerprint of the given data.

    Args:
        data (bytes): The data to fingerprint
        prefix (bool, optional): Whether to prefix the fingerprint with SHA512:

    Returns:
        str: The fingerprint in OpenSSH style: ``SHA512:`` followed by the
        base64-encoded digest with its ``=`` padding removed.
    """
    digest = hl.sha512(data).digest()
    # base64 padding only ever appears at the end, so rstrip drops exactly it.
    encoded = b64encode(digest).decode("utf-8").rstrip("=")
    return ("SHA512:" if prefix else "") + encoded
def str_to_time_delta(str_delta: str) ‑> datetime.timedelta

Uses the package pytimeparse2 by wroberts/onegreyonewhite to convert a string into a timedelta object. Examples:

- 32m
- 2h32m
- 3d2h32m
- 1w3d2h32m
- 1w 3d 2h 32m
- 1 w 3 d 2 h 32 m
- 4:13
- 4:13:02
- 4:13:02.266
- 2:04:13:02.266
- 2 days, 4:13:02 (uptime format)
- 2 days, 4:13:02.266
- 5hr34m56s
- 5 hours, 34 minutes, 56 seconds
- 5 hrs, 34 mins, 56 secs
- 2 days, 5 hours, 34 minutes, 56 seconds
- 1.2 m
- 1.2 min
- 1.2 mins
- 1.2 minute
- 1.2 minutes
- 172 hours
- 172 hr
- 172 h
- 172 hrs
- 172 hour
- 1.24 days
- 5 d
- 5 day
- 5 days
- 5.6 wk
- 5.6 week
- 5.6 weeks

Args

str_delta : str
The time delta string to convert

Returns

datetime.timedelta
The time delta object
Expand source code
def str_to_time_delta(str_delta: str) -> datetime.timedelta:
    """Parse a human-readable duration string into a timedelta.

    Parsing is delegated to the pytimeparse2 package (by wroberts/onegreyonewhite).
    Accepted formats include, for example:
        - 32m
        - 2h32m
        - 3d2h32m
        - 1w3d2h32m
        - 1w 3d 2h 32m
        - 1 w 3 d 2 h 32 m
        - 4:13
        - 4:13:02
        - 4:13:02.266
        - 2:04:13:02.266
        - 2 days, 4:13:02 (uptime format)
        - 2 days, 4:13:02.266
        - 5hr34m56s
        - 5 hours, 34 minutes, 56 seconds
        - 5 hrs, 34 mins, 56 secs
        - 2 days, 5 hours, 34 minutes, 56 seconds
        - 1.2 m
        - 1.2 min
        - 1.2 mins
        - 1.2 minute
        - 1.2 minutes
        - 172 hours
        - 172 hr
        - 172 h
        - 172 hrs
        - 172 hour
        - 1.24 days
        - 5 d
        - 5 day
        - 5 days
        - 5.6 wk
        - 5.6 week
        - 5.6 weeks

    Args:
        str_delta (str): The duration string to parse.

    Returns:
        datetime.timedelta: The parsed duration.

    Raises:
        ValueError: If parsing fails for any reason; the original exception is
            chained as ``__cause__``.
    """
    try:
        return time_parse(str_delta, as_timedelta=True, raise_exception=True)
    except Exception as parse_error:
        raise ValueError(
            f"Could not parse time delta string {str_delta} : {parse_error}"
        ) from parse_error