[PATCH 3/4] hw/usb/u2f-passthru: Clean up code

David Bouman posted 4 patches 5 months ago
[PATCH 3/4] hw/usb/u2f-passthru: Clean up code
Posted by David Bouman 5 months ago
Prepare for implementing the FIDO-U2F keepalive feature:

Represent all U2FHID frames (initialization and continuation) with a
single packed structure, and make the casts from raw packet buffers
explicit.

Signed-off-by: David Bouman <dbouman03@gmail.com>
---
 hw/usb/u2f-passthru.c | 73 ++++++++++++++++++++++++++-----------------
 1 file changed, 44 insertions(+), 29 deletions(-)

diff --git a/hw/usb/u2f-passthru.c b/hw/usb/u2f-passthru.c
index 54062ab4d5..d0fb7b377c 100644
--- a/hw/usb/u2f-passthru.c
+++ b/hw/usb/u2f-passthru.c
@@ -87,30 +87,45 @@ struct U2FPassthruState {
 #define PACKET_CONT_HEADER_SIZE 5
 #define PACKET_CONT_DATA_SIZE (U2FHID_PACKET_SIZE - PACKET_CONT_HEADER_SIZE)
 
-struct packet_init {
+/* U2FHID frame definition and command bytes (all have bit 7 set) */
+
+#define U2FHID_CMD_PING      0x81
+#define U2FHID_CMD_MSG       0x83
+#define U2FHID_CMD_INIT      0x86
+#define U2FHID_CMD_WINK      0x88
+#define U2FHID_CMD_CBOR      0x90
+#define U2FHID_CMD_CANCEL    0x91
+#define U2FHID_CMD_KEEPALIVE 0xbb
+#define U2FHID_ERROR         0xbf
+
+struct u2fhid_frame {
     uint32_t cid;
-    uint8_t cmd;
-    uint8_t bcnth;
-    uint8_t bcntl;
-    uint8_t data[PACKET_INIT_DATA_SIZE];
+    union {
+        uint8_t type;  /* CMD or SEQ byte; bit 7 set => init frame */
+        struct {
+            uint8_t cmd;
+            uint8_t bcnth;
+            uint8_t bcntl;
+            uint8_t data[PACKET_INIT_DATA_SIZE];
+        } init;        /* initialization frame: type is cmd */
+        struct {
+            uint8_t seq;
+            uint8_t data[PACKET_CONT_DATA_SIZE];
+        } cont;        /* continuation frame: type is seq */
+    };
 } QEMU_PACKED;
 
-static inline uint32_t packet_get_cid(const void *packet)
+static inline bool packet_is_init(const struct u2fhid_frame *packet)
 {
-    return *((uint32_t *)packet);
-}
-
-static inline bool packet_is_init(const void *packet)
-{
-    return ((uint8_t *)packet)[4] & (1 << 7);
+    return !!(packet->type & (1 << 7)); /* CMD bytes have bit 7 set */
 }
 
 static inline uint16_t packet_init_get_bcnt(
-        const struct packet_init *packet_init)
+        const struct u2fhid_frame *packet)
 {
     uint16_t bcnt = 0;
-    bcnt |= packet_init->bcnth << 8;
-    bcnt |= packet_init->bcntl;
+    bcnt |= (uint16_t)(packet->init.bcnth) << 8;
+    bcnt |= packet->init.bcntl;
 
     return bcnt;
 }
@@ -237,13 +252,13 @@ static void u2f_transaction_add(U2FPassthruState *key, uint32_t cid,
 static void u2f_passthru_read(void *opaque);
 
 static void u2f_transaction_start(U2FPassthruState *key,
-                                  const struct packet_init *packet_init)
+                                  const struct u2fhid_frame *packet_init)
 {
     int64_t time;
 
     /* Transaction */
     if (packet_init->cid == BROADCAST_CID) {
-        u2f_transaction_add(key, packet_init->cid, packet_init->data);
+        u2f_transaction_add(key, packet_init->cid, packet_init->init.data);
     } else {
         u2f_transaction_add(key, packet_init->cid, NULL);
     }
@@ -259,20 +274,19 @@ static void u2f_transaction_start(U2FPassthruState *key,
 }
 
 static void u2f_passthru_recv_from_host(U2FPassthruState *key,
-                                    const uint8_t packet[U2FHID_PACKET_SIZE])
+                                const uint8_t raw_packet[U2FHID_PACKET_SIZE])
 {
     struct transaction *transaction;
     uint32_t cid;
 
+    const struct u2fhid_frame *packet = (const void *)raw_packet;
     /* Retrieve transaction */
-    cid = packet_get_cid(packet);
+    cid = packet->cid;
     if (cid == BROADCAST_CID) {
-        struct packet_init *packet_init;
         if (!packet_is_init(packet)) {
             return;
         }
-        packet_init = (struct packet_init *)packet;
-        transaction = u2f_transaction_get_from_nonce(key, packet_init->data);
+        transaction = u2f_transaction_get_from_nonce(key, packet->init.data);
     } else {
         transaction = u2f_transaction_get(key, cid);
     }
@@ -283,13 +297,12 @@ static void u2f_passthru_recv_from_host(U2FPassthruState *key,
     }
 
     if (packet_is_init(packet)) {
-        struct packet_init *packet_init = (struct packet_init *)packet;
-        transaction->resp_bcnt = packet_init_get_bcnt(packet_init);
+        transaction->resp_bcnt = packet_init_get_bcnt(packet);
         transaction->resp_size = PACKET_INIT_DATA_SIZE;
 
-        if (packet_init->cid == BROADCAST_CID) {
+        if (packet->cid == BROADCAST_CID) {
             /* Nonce checking for legitimate response */
-            if (memcmp(transaction->nonce, packet_init->data, NONCE_SIZE)
+            if (memcmp(transaction->nonce, packet->init.data, NONCE_SIZE)
                 != 0) {
                 return;
             }
@@ -302,7 +315,7 @@ static void u2f_passthru_recv_from_host(U2FPassthruState *key,
     if (transaction->resp_size >= transaction->resp_bcnt) {
         u2f_transaction_close(key, cid);
     }
-    u2f_send_to_guest(&key->base, packet);
+    u2f_send_to_guest(&key->base, raw_packet);
 }
 
 static void u2f_passthru_read(void *opaque)
@@ -333,14 +346,16 @@ static void u2f_passthru_read(void *opaque)
 }
 
 static void u2f_passthru_recv_from_guest(U2FKeyState *base,
-                                    const uint8_t packet[U2FHID_PACKET_SIZE])
+                                const uint8_t raw_packet[U2FHID_PACKET_SIZE])
 {
     U2FPassthruState *key = PASSTHRU_U2F_KEY(base);
     uint8_t host_packet[U2FHID_PACKET_SIZE + 1];
     ssize_t written;
 
+    const struct u2fhid_frame *packet = (const void *)raw_packet;
+
     if (packet_is_init(packet)) {
-        u2f_transaction_start(key, (struct packet_init *)packet);
+        u2f_transaction_start(key, packet);
     }
 
     host_packet[0] = 0;
-- 
2.34.1