diff --git a/server/tests/test-sasl.c b/server/tests/test-sasl.c index 7c1181dc..726b4029 100644 --- a/server/tests/test-sasl.c +++ b/server/tests/test-sasl.c @@ -41,11 +41,29 @@ typedef struct SPICE_ATTR_PACKED SpiceInitialMessage { } SpiceInitialMessage; #include +typedef enum { + FLAG_NONE = 0, + FLAG_START_OK = 1, + FLAG_LOW_SSF = 2, + FLAG_SERVER_NULL_START = 4, + FLAG_SERVER_NULL_STEP = 8, + FLAG_CLIENT_NULL_START = 16, + FLAG_CLIENT_NULL_STEP = 32, + FLAG_SERVER_BIG_START = 64, + FLAG_SERVER_BIG_STEP = 128, + FLAG_CLIENT_BIG_START = 256, + FLAG_CLIENT_BIG_STEP = 512, + FLAG_START_ERROR = 1024, + FLAG_STEP_ERROR = 2048, +} TestFlags; + static char *mechlist; +static char *big_data; static bool mechlist_called; static bool start_called; static bool step_called; static bool encode_called; +static unsigned test_flags; static SpiceCoreInterface *core; static SpiceServer *server; @@ -60,6 +78,22 @@ check_sasl_conn(sasl_conn_t *conn) g_assert_nonnull(conn); } +static void +get_step_out(const char **serverout, unsigned *serveroutlen, + const char *normal_data, unsigned null_flag, unsigned big_flag) +{ + if ((test_flags & big_flag) != 0) { + *serverout = big_data; + *serveroutlen = strlen(big_data); + } else if ((test_flags & null_flag) != 0) { + *serverout = NULL; + *serveroutlen = 0; + } else { + *serverout = normal_data; + *serveroutlen = strlen(normal_data); + } +} + int sasl_server_init(const sasl_callback_t *callbacks, const char *appname) { @@ -113,7 +147,8 @@ sasl_getprop(sasl_conn_t *conn, int propnum, g_assert_nonnull(pvalue); if (propnum == SASL_SSF) { - static const int val = 64; + static int val; + val = (test_flags & FLAG_LOW_SSF) ? 44 : 64; *pvalue = &val; } return SASL_OK; @@ -194,9 +229,11 @@ sasl_server_start(sasl_conn_t *conn, g_assert(!step_called); start_called = true; - *serverout = "foo"; - *serveroutlen = 3; - return SASL_OK; + get_step_out(serverout, serveroutlen, "foo", FLAG_SERVER_NULL_START, FLAG_SERVER_BIG_START); + if (test_flags & FLAG_START_ERROR) { + return SASL_FAIL; + } + return (test_flags & FLAG_START_OK) ? SASL_OK : SASL_CONTINUE; } int @@ -211,8 +248,10 @@ sasl_server_step(sasl_conn_t *conn, g_assert(start_called); step_called = true; - *serverout = "foo"; - *serveroutlen = 3; + get_step_out(serverout, serveroutlen, "foo", FLAG_SERVER_NULL_STEP, FLAG_SERVER_BIG_STEP); + if (test_flags & FLAG_STEP_ERROR) { + return SASL_FAIL; + } return SASL_OK; } @@ -244,11 +283,19 @@ reset_test(void) start_called = false; step_called = false; encode_called = false; + test_flags = FLAG_NONE; } static void start_test(void) { + g_assert_null(big_data); + big_data = g_malloc(1024 * 1024 + 10); + for (unsigned n = 0; n < 1024 * 1024 + 10; ++n) { + big_data[n] = ' ' + (n % 94); + } + big_data[1024 * 1024 + 5] = 0; + g_assert_null(server); initial_message.hdr.magic = SPICE_MAGIC; @@ -275,6 +322,9 @@ end_tests(void) g_free(mechlist); mechlist = NULL; + + g_free(big_data); + big_data = NULL; } static size_t @@ -343,24 +393,40 @@ idle_add(GSourceFunc func, void *arg) g_source_unref(source); } +typedef enum { + STEP_NONE, + STEP_READ_MECHLIST_LEN, + STEP_READ_MECHLIST, + STEP_WRITE_MECHNAME_LEN, + STEP_WRITE_MECHNAME, + STEP_WRITE_START_LEN, + STEP_WRITE_START, + STEP_WRITE_STEP_LEN, + STEP_WRITE_STEP, + STEP_NEVER, +} ClientEmulationSteps; + typedef struct { const char *mechname; int mechlen; bool success; + ClientEmulationSteps last_step; + unsigned flags; + int line; } TestData; static char long_mechname[128]; static TestData tests_data[] = { // these should just succeed #define TEST_SUCCESS(mech) \ - { mech, -1, true }, + { mech, -1, true, STEP_NEVER, FLAG_NONE, __LINE__ }, TEST_SUCCESS("ONE") TEST_SUCCESS("TWO") TEST_SUCCESS("THREE") // these test bad mech names #define TEST_BAD_NAME(mech, len) \ - { mech, len, false }, + { mech, len, false, STEP_NEVER, FLAG_NONE, __LINE__ }, TEST_BAD_NAME("ON", -1) TEST_BAD_NAME("NE", -1) TEST_BAD_NAME("THRE", -1) @@ -371,14 +437,41 @@ static TestData tests_data[] = { TEST_BAD_NAME(long_mechname, 100) TEST_BAD_NAME(long_mechname, 101) TEST_BAD_NAME("ONE,TWO", -1) + + // stop before filling everything +#define TEST_EARLY_STOP(step) \ + { "ONE", -1, false, step, FLAG_NONE, __LINE__}, + TEST_EARLY_STOP(STEP_READ_MECHLIST_LEN) + TEST_EARLY_STOP(STEP_READ_MECHLIST) + TEST_EARLY_STOP(STEP_WRITE_MECHNAME_LEN) + TEST_EARLY_STOP(STEP_WRITE_MECHNAME) + TEST_EARLY_STOP(STEP_WRITE_START_LEN) + TEST_EARLY_STOP(STEP_WRITE_START) + TEST_EARLY_STOP(STEP_WRITE_STEP_LEN) + +#define TEST_FLAGS(result, flags) \ + { "ONE", -1, result, STEP_NEVER, flags, __LINE__}, + TEST_FLAGS(false, FLAG_LOW_SSF) + TEST_FLAGS(false, FLAG_START_OK|FLAG_LOW_SSF) + TEST_FLAGS(true, FLAG_START_OK) + TEST_FLAGS(true, FLAG_SERVER_NULL_START) + TEST_FLAGS(true, FLAG_SERVER_NULL_STEP) + TEST_FLAGS(true, FLAG_CLIENT_NULL_START) + TEST_FLAGS(true, FLAG_CLIENT_NULL_STEP) + TEST_FLAGS(false, FLAG_SERVER_BIG_START) + TEST_FLAGS(false, FLAG_SERVER_BIG_STEP) + TEST_FLAGS(false, FLAG_CLIENT_BIG_START) + TEST_FLAGS(false, FLAG_CLIENT_BIG_STEP) + TEST_FLAGS(false, FLAG_START_ERROR) + TEST_FLAGS(false, FLAG_STEP_ERROR) }; -static void * -client_emulator(void *arg) +static void +client_emulator(int sock) { const TestData *data = &tests_data[test_num]; - int sock = GPOINTER_TO_INT(arg); +#define STOP_AT(step) if (data->last_step == STEP_ ## step) { return; } // send initial message write_all(sock, &initial_message, sizeof(initial_message)); @@ -402,23 +495,68 @@ client_emulator(void *arg) // mech SPICE_COMMON_CAP_AUTH_SASL) write_u32(sock, SPICE_COMMON_CAP_AUTH_SASL); + STOP_AT(NONE); + // sasl finally start, data starts from server (mech list) // uint32_t mechlen; read_u32(sock, &mechlen); + STOP_AT(READ_MECHLIST_LEN); + char buf[300]; g_assert_cmpint(mechlen, <=, sizeof(buf)); read_all(sock, buf, mechlen); + STOP_AT(READ_MECHLIST); // mech name write_u32(sock, data->mechlen); + STOP_AT(WRITE_MECHNAME_LEN); write_all(sock, data->mechname, data->mechlen); + STOP_AT(WRITE_MECHNAME); // first challenge - if (write_u32_err(sock, 5) == sizeof(uint32_t)) { - do_readwrite_all(sock, "START", 5, true); + const char *out; + unsigned outlen; + get_step_out(&out, &outlen, "START", FLAG_CLIENT_NULL_START, FLAG_CLIENT_BIG_START); + if (write_u32_err(sock, out ? outlen : 0) == sizeof(uint32_t)) { + STOP_AT(WRITE_START_LEN); + if (out) { + if (do_readwrite_all(sock, out, outlen, true) != outlen) { + return; + } + } + STOP_AT(WRITE_START); } + uint32_t datalen; + if (read_u32_err(sock, &datalen) != sizeof(datalen)) { + return; + } + if (datalen == GUINT32_FROM_LE(SPICE_MAGIC)) { + return; + } + g_assert_cmpint(datalen, <=, sizeof(buf)); + read_all(sock, buf, datalen); + + get_step_out(&out, &outlen, "STEP", FLAG_CLIENT_NULL_STEP, FLAG_CLIENT_BIG_STEP); + if (write_u32_err(sock, out ? outlen : 0) == sizeof(uint32_t)) { + STOP_AT(WRITE_STEP_LEN); + if (out) { + if (do_readwrite_all(sock, out, outlen, true) != outlen) { + return; + } + } + STOP_AT(WRITE_STEP); + } +} + +static void * +client_emulator_thread(void *arg) +{ + int sock = GPOINTER_TO_INT(arg); + + client_emulator(sock); + shutdown(sock, SHUT_RDWR); close(sock); @@ -434,8 +572,10 @@ setup_thread(void) if (data->mechlen < 0) { data->mechlen = strlen(data->mechname); } + test_flags = data->flags; int len = data->mechlen; - printf("\nRunning test %d ('%*.*s' %d)\n", test_num, len, len, data->mechname, len); + printf("\nRunning test %d ('%.*s' %d) line %d\n", + test_num, len, data->mechname, len, data->line); int sv[2]; g_assert_cmpint(socketpair(AF_LOCAL, SOCK_STREAM, 0, sv), ==, 0); @@ -443,7 +583,8 @@ setup_thread(void) g_assert(spice_server_add_client(server, sv[0], 0) == 0); pthread_t thread; - g_assert_cmpint(pthread_create(&thread, NULL, client_emulator, GINT_TO_POINTER(sv[1])), ==, 0); + g_assert_cmpint(pthread_create(&thread, NULL, client_emulator_thread, + GINT_TO_POINTER(sv[1])), ==, 0); return thread; }