diff --git a/tests/topotests/lib/fe_client.py b/tests/topotests/lib/fe_client.py index 078df8cb33..4868c1018e 100755 --- a/tests/topotests/lib/fe_client.py +++ b/tests/topotests/lib/fe_client.py @@ -8,6 +8,7 @@ # # noqa: E501 # +"""A MGMTD front-end client.""" import argparse import logging import os @@ -34,7 +35,7 @@ except Exception as error: try: sys.path[0:0] = "." - import mgmt_pb2 + import mgmt_pb2 # pylint: disable=E0401 except Exception as error: logging.error("can't import proto definition modules %s", error) raise @@ -124,16 +125,22 @@ MSG_FORMAT_LYB = 3 def cstr(mdata): + """Convert a null-term byte array into a string, excluding the null terminator.""" assert mdata[-1] == 0 return mdata[:-1] class FEClientError(Exception): + """Base class for frontend client errors.""" + pass class PBMessageError(FEClientError): + """Exception for errors related to protobuf messages.""" + def __init__(self, msg, errstr): + """Initialize PBMessageError with message and error string.""" self.msg = msg # self.sess_id = mhdr[HDR_FIELD_SESS_ID] # self.req_id = mhdr[HDR_FIELD_REQ_ID] @@ -143,7 +150,10 @@ class PBMessageError(FEClientError): class NativeMessageError(FEClientError): + """Exception for errors related to native messages.""" + def __init__(self, mhdr, mfixed, mdata): + """Initialize NativeMessageError with message header, fixed fields, and data.""" self.mhdr = mhdr self.sess_id = mhdr[HDR_FIELD_SESS_ID] self.req_id = mhdr[HDR_FIELD_REQ_ID] @@ -173,6 +183,7 @@ def recv_wait(sock, size): def recv_msg(sock): + """Receive a message from the socket, ensuring it has a valid marker.""" marker = recv_wait(sock, 4) assert marker in (MGMT_MSG_MARKER_PROTOBUF, MGMT_MSG_MARKER_NATIVE) @@ -197,15 +208,18 @@ class Session: client_id = 1 def __init__(self, sock, use_protobuf): + """Initialize a session with the mgmtd server.""" self.sock = sock self.next_req_id = 1 if use_protobuf: + # Register the client req = mgmt_pb2.FeMessage() req.register_req.client_name = "test-client" self.send_pb_msg(req) logging.debug("Sent FeRegisterReq: %s", req) + # Create a session req = mgmt_pb2.FeMessage() req.session_req.create = 1 req.session_req.client_conn_id = Session.client_id @@ -219,11 +233,11 @@ class Session: assert reply.session_reply.success self.sess_id = reply.session_reply.session_id else: + # Establish a native session self.sess_id = 0 mdata, _ = self.get_native_msg_header(MSG_CODE_SESSION_REQ) mdata += struct.pack(MSG_SESSION_REQ_FMT) mdata += "test-client".encode("utf-8") + b"\x00" - self.send_native_msg(mdata) logging.debug("Sent native SESSION-REQ") @@ -236,6 +250,7 @@ class Session: self.sess_id = mhdr[HDR_FIELD_SESS_ID] def close(self, clean=True): + """Close the session.""" if clean: req = mgmt_pb2.FeMessage() req.session_req.create = 0 @@ -245,6 +260,7 @@ class Session: self.sock = None def get_next_req_id(self): + """Generate the next request ID for a new session.""" req_id = self.next_req_id self.next_req_id += 1 return req_id @@ -301,10 +317,11 @@ class Session: return mhdr, mfixed, mdata def send_native_msg(self, mdata): - """Send a native message.""" + """Send a native message to the mgmtd server.""" return send_msg(self.sock, MGMT_MSG_MARKER_NATIVE, mdata) def get_native_msg_header(self, msg_code): + """Generate a native message header for a given message code.""" req_id = self.get_next_req_id() hdata = struct.pack(MSG_HDR_FMT, msg_code, 0, self.sess_id, req_id) return hdata, req_id @@ -314,6 +331,19 @@ class Session: # ----------------------- def lock(self, lock=True, ds_id=mgmt_pb2.CANDIDATE_DS): + """Lock or unlock a datastore. + + Args: + lock (bool, optional): Whether to lock (True) or unlock (False) the + datastore. Defaults to True. + ds_id (int, optional): The datastore ID. Defaults to mgmt_pb2.CANDIDATE_DS. + + Returns: + None + + Raises: + AssertionError: If the lock request fails. + """ req = mgmt_pb2.FeMessage() req.lockds_req.session_id = self.sess_id req.lockds_req.req_id = self.get_next_req_id() @@ -327,7 +357,20 @@ class Session: assert reply.lockds_reply.success def get_data(self, query, data=True, config=False): - # Create the message + """Retrieve data from the mgmtd server based on an XPath query. + + Args: + query (str): The XPath query string. + data (bool, optional): Whether to retrieve state data. Defaults to True. + config (bool, optional): Whether to retrieve configuration data. + Defaults to False. + + Returns: + str: The retrieved data in JSON format. + + Raises: + AssertionError: If the response data is not properly formatted. + """ mdata, _ = self.get_native_msg_header(MSG_CODE_GET_DATA) flags = GET_DATA_FLAG_STATE if data else 0 flags |= GET_DATA_FLAG_CONFIG if config else 0 @@ -345,7 +388,12 @@ class Session: return result def add_notify_select(self, replace, notif_xpaths): - # Create the message + """Send a request to add notification subscriptions to the given XPaths. + + Args: + replace (bool): Whether to replace existing notification subscriptions. + notif_xpaths (list of str): List of XPaths to subscribe to notifications on. + """ mdata, _ = self.get_native_msg_header(MSG_CODE_NOTIFY_SELECT) mdata += struct.pack(MSG_NOTIFY_SELECT_FMT, replace) @@ -356,6 +404,18 @@ class Session: logging.debug("Sent NOTIFY_SELECT") def recv_notify(self, xpaths=None): + """Receive a notification message, optionally setting up XPath filters first. + + Args: + xpaths (list of str, optional): List of XPaths to filter notifications. + + Returns: + tuple: (result_type, operation, xpath, message data) + + Raises: + TimeoutError: If no notification is received within the timeout period. + Exception: If a non-notification message is received. + """ if xpaths: self.add_notify_select(True, xpaths) @@ -379,6 +439,7 @@ class Session: def __parse_args(): + """Parse command-line arguments for the mgmtd client.""" MPATH = "/var/run/frr/mgmtd_fe.sock" parser = argparse.ArgumentParser() parser.add_argument( @@ -416,6 +477,14 @@ def __parse_args(): def __server_connect(spath): + """Establish a connection to the mgmtd server over a Unix socket. + + Args: + spath (str): Path to the Unix domain socket. + + Returns: + socket: A connected Unix socket. + """ sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) logging.debug("Connecting to server on %s", spath) while ec := sock.connect_ex(str(spath)): @@ -428,6 +497,7 @@ def __server_connect(spath): def __main(): + """Process client commands and handle queries or notifications.""" args = __parse_args() sock = __server_connect(Path(args.server)) sess = Session(sock, use_protobuf=args.use_protobuf) @@ -473,6 +543,7 @@ def __main(): def main(): + """Entry point for the mgmtd client application.""" try: __main() except KeyboardInterrupt: