tests: add docstrings to frontend mgmtd client

Signed-off-by: Christian Hopps <chopps@labn.net>
This commit is contained in:
Christian Hopps 2025-02-21 18:47:15 +00:00
parent eb6f49ff1d
commit 84ab204ebe

View File

@ -8,6 +8,7 @@
#
# noqa: E501
#
"""A MGMTD front-end client."""
import argparse
import logging
import os
@ -34,7 +35,7 @@ except Exception as error:
try:
sys.path[0:0] = "."
import mgmt_pb2
import mgmt_pb2 # pylint: disable=E0401
except Exception as error:
logging.error("can't import proto definition modules %s", error)
raise
@ -124,16 +125,22 @@ MSG_FORMAT_LYB = 3
def cstr(mdata):
"""Convert a null-term byte array into a string, excluding the null terminator."""
assert mdata[-1] == 0
return mdata[:-1]
class FEClientError(Exception):
"""Base class for frontend client errors."""
pass
class PBMessageError(FEClientError):
"""Exception for errors related to protobuf messages."""
def __init__(self, msg, errstr):
"""Initialize PBMessageError with message and error string."""
self.msg = msg
# self.sess_id = mhdr[HDR_FIELD_SESS_ID]
# self.req_id = mhdr[HDR_FIELD_REQ_ID]
@ -143,7 +150,10 @@ class PBMessageError(FEClientError):
class NativeMessageError(FEClientError):
"""Exception for errors related to native messages."""
def __init__(self, mhdr, mfixed, mdata):
"""Initialize NativeMessageError with message header, fixed fields, and data."""
self.mhdr = mhdr
self.sess_id = mhdr[HDR_FIELD_SESS_ID]
self.req_id = mhdr[HDR_FIELD_REQ_ID]
@ -173,6 +183,7 @@ def recv_wait(sock, size):
def recv_msg(sock):
"""Receive a message from the socket, ensuring it has a valid marker."""
marker = recv_wait(sock, 4)
assert marker in (MGMT_MSG_MARKER_PROTOBUF, MGMT_MSG_MARKER_NATIVE)
@ -197,15 +208,18 @@ class Session:
client_id = 1
def __init__(self, sock, use_protobuf):
"""Initialize a session with the mgmtd server."""
self.sock = sock
self.next_req_id = 1
if use_protobuf:
# Register the client
req = mgmt_pb2.FeMessage()
req.register_req.client_name = "test-client"
self.send_pb_msg(req)
logging.debug("Sent FeRegisterReq: %s", req)
# Create a session
req = mgmt_pb2.FeMessage()
req.session_req.create = 1
req.session_req.client_conn_id = Session.client_id
@ -219,11 +233,11 @@ class Session:
assert reply.session_reply.success
self.sess_id = reply.session_reply.session_id
else:
# Establish a native session
self.sess_id = 0
mdata, _ = self.get_native_msg_header(MSG_CODE_SESSION_REQ)
mdata += struct.pack(MSG_SESSION_REQ_FMT)
mdata += "test-client".encode("utf-8") + b"\x00"
self.send_native_msg(mdata)
logging.debug("Sent native SESSION-REQ")
@ -236,6 +250,7 @@ class Session:
self.sess_id = mhdr[HDR_FIELD_SESS_ID]
def close(self, clean=True):
"""Close the session."""
if clean:
req = mgmt_pb2.FeMessage()
req.session_req.create = 0
@ -245,6 +260,7 @@ class Session:
self.sock = None
def get_next_req_id(self):
"""Generate the next request ID for a new session."""
req_id = self.next_req_id
self.next_req_id += 1
return req_id
@ -301,10 +317,11 @@ class Session:
return mhdr, mfixed, mdata
def send_native_msg(self, mdata):
"""Send a native message."""
"""Send a native message to the mgmtd server."""
return send_msg(self.sock, MGMT_MSG_MARKER_NATIVE, mdata)
def get_native_msg_header(self, msg_code):
"""Generate a native message header for a given message code."""
req_id = self.get_next_req_id()
hdata = struct.pack(MSG_HDR_FMT, msg_code, 0, self.sess_id, req_id)
return hdata, req_id
@ -314,6 +331,19 @@ class Session:
# -----------------------
def lock(self, lock=True, ds_id=mgmt_pb2.CANDIDATE_DS):
"""Lock or unlock a datastore.
Args:
lock (bool, optional): Whether to lock (True) or unlock (False) the
datastore. Defaults to True.
ds_id (int, optional): The datastore ID. Defaults to mgmt_pb2.CANDIDATE_DS.
Returns:
None
Raises:
AssertionError: If the lock request fails.
"""
req = mgmt_pb2.FeMessage()
req.lockds_req.session_id = self.sess_id
req.lockds_req.req_id = self.get_next_req_id()
@ -327,7 +357,20 @@ class Session:
assert reply.lockds_reply.success
def get_data(self, query, data=True, config=False):
# Create the message
"""Retrieve data from the mgmtd server based on an XPath query.
Args:
query (str): The XPath query string.
data (bool, optional): Whether to retrieve state data. Defaults to True.
config (bool, optional): Whether to retrieve configuration data.
Defaults to False.
Returns:
str: The retrieved data in JSON format.
Raises:
AssertionError: If the response data is not properly formatted.
"""
mdata, _ = self.get_native_msg_header(MSG_CODE_GET_DATA)
flags = GET_DATA_FLAG_STATE if data else 0
flags |= GET_DATA_FLAG_CONFIG if config else 0
@ -345,7 +388,12 @@ class Session:
return result
def add_notify_select(self, replace, notif_xpaths):
# Create the message
"""Send a request to add notification subscriptions to the given XPaths.
Args:
replace (bool): Whether to replace existing notification subscriptions.
notif_xpaths (list of str): List of XPaths to subscribe to notifications on.
"""
mdata, _ = self.get_native_msg_header(MSG_CODE_NOTIFY_SELECT)
mdata += struct.pack(MSG_NOTIFY_SELECT_FMT, replace)
@ -356,6 +404,18 @@ class Session:
logging.debug("Sent NOTIFY_SELECT")
def recv_notify(self, xpaths=None):
"""Receive a notification message, optionally setting up XPath filters first.
Args:
xpaths (list of str, optional): List of XPaths to filter notifications.
Returns:
tuple: (result_type, operation, xpath, message data)
Raises:
TimeoutError: If no notification is received within the timeout period.
Exception: If a non-notification message is received.
"""
if xpaths:
self.add_notify_select(True, xpaths)
@ -379,6 +439,7 @@ class Session:
def __parse_args():
"""Parse command-line arguments for the mgmtd client."""
MPATH = "/var/run/frr/mgmtd_fe.sock"
parser = argparse.ArgumentParser()
parser.add_argument(
@ -416,6 +477,14 @@ def __parse_args():
def __server_connect(spath):
"""Establish a connection to the mgmtd server over a Unix socket.
Args:
spath (str): Path to the Unix domain socket.
Returns:
socket: A connected Unix socket.
"""
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
logging.debug("Connecting to server on %s", spath)
while ec := sock.connect_ex(str(spath)):
@ -428,6 +497,7 @@ def __server_connect(spath):
def __main():
"""Process client commands and handle queries or notifications."""
args = __parse_args()
sock = __server_connect(Path(args.server))
sess = Session(sock, use_protobuf=args.use_protobuf)
@ -473,6 +543,7 @@ def __main():
def main():
"""Entry point for the mgmtd client application."""
try:
__main()
except KeyboardInterrupt: