C++, Rust & Sum Types

NB. This is an expansion of a conversation I had in the Rust sub-Reddit.

Introduction

Most programming languages support enumerations, to some degree at least. By "enumeration" I mean what could more formally be termed a "sum type": a type that is the disjoint union of a set of other types. A variable of sum type A is either a B, a C, or a D (but that's it: it must be exactly one of those three). I've also seen these termed "variant types".

Rust supports this directly (example copied from the famous Rust book):

struct Ipv4Addr {
    // --snip--
}

struct Ipv6Addr {
    // --snip--
}

enum IpAddr {
    V4(Ipv4Addr),
    V6(Ipv6Addr),
}

Rust provides the match expression to let you "destructure" a variable of type IpAddr conveniently:

fn do_something(a: IpAddr) {
    match a {
        IpAddr::V4(v4addr) => {
            // `v4addr` is of type Ipv4Addr
            v4addr.do_something_specific_to_v4();
        },
        IpAddr::V6(v6addr) => {
            // `v6addr` is of type Ipv6Addr
            v6addr.do_something_specific_to_v6();
        }
        // if you ever add a new variant to the enumeration, and forget
        // to update this match statement, the compiler will catch it;
        // incomplete matches are a compile error in Rust (E0004)
    }
}
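The snippet above leans on snipped struct bodies and on methods that are never defined, so it won't compile as-is. Here is a minimal self-contained sketch that does. The field layouts and the describe function are hypothetical stand-ins, chosen only to show each arm binding its payload at the concrete type.

```rust
// Hypothetical stand-ins for the snipped Ipv4Addr/Ipv6Addr bodies.
struct Ipv4Addr {
    octets: [u8; 4],
}

struct Ipv6Addr {
    segments: [u16; 8],
}

enum IpAddr {
    V4(Ipv4Addr),
    V6(Ipv6Addr),
}

// Each arm binds the payload at its concrete type; removing an arm
// is rejected at compile time (E0004, non-exhaustive patterns).
fn describe(a: &IpAddr) -> String {
    match a {
        IpAddr::V4(v4) => {
            let [b0, b1, b2, b3] = v4.octets;
            format!("{b0}.{b1}.{b2}.{b3}")
        }
        // Hex segments joined by ':', with no "::" compression.
        IpAddr::V6(v6) => v6
            .segments
            .iter()
            .map(|s| format!("{s:x}"))
            .collect::<Vec<_>>()
            .join(":"),
    }
}

fn main() {
    let home = IpAddr::V4(Ipv4Addr { octets: [127, 0, 0, 1] });
    let loopback = IpAddr::V6(Ipv6Addr { segments: [0, 0, 0, 0, 0, 0, 0, 1] });
    println!("{}", describe(&home)); // prints 127.0.0.1
    println!("{}", describe(&loopback)); // prints 0:0:0:0:0:0:0:1
}
```

Deleting either arm of the match in describe turns the program into a compile error, which is the guarantee the comment in the original example describes.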

Modern C++ provides std::variant (on which more below), but more traditionally one would define an enumerated type (i.e. a special case of a sum type in which each sub-type has only one value) & the sub-types separately:

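class IpAddr {
public:
    // dynamic_cast needs a polymorphic base, i.e. at least one virtual member
    virtual ~IpAddr() = default;
    // --snip--
};

class Ipv4Addr: public IpAddr {
    // --snip--
};

class Ipv6Addr: public IpAddr {
    // --snip--
};

enum class IP_ADDR {
    V4,
    V6,
};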

The connection is generally implicit & only made at run-time:

#include <stdexcept> // std::logic_error

void do_something(IP_ADDR e, IpAddr &a) {
    switch (e) {
    case IP_ADDR::V4: {
      // a must be of type Ipv4Addr... we hope
      Ipv4Addr &a4 = dynamic_cast<Ipv4Addr&>(a); // throws std::bad_cast if caller got it wrong!
      a4.do_something_specific_to_v4();
      break;
    }
    case IP_ADDR::V6: {
      // a must be of type Ipv6Addr... we hope
      Ipv6Addr &a6 = dynamic_cast<Ipv6Addr&>(a); // throws std::bad_cast if caller got it wrong!
      a6.do_something_specific_to_v6();
      break;
    }
    // what if I add an enum value & forget to update this method?
    default:
      throw std::logic_error("unhandled IP_ADDR value");
    }
}

There are two problems here:

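  1. What happens if, in the future, you add a value to your enumeration & forget to update the method? In Rust, that will be a compiler error. In C++, you may be able to induce your compiler to warn you (GCC's and Clang's -Wswitch will, but only for a switch with no default label; -Wswitch-enum warns even when there is one), but that's the best-case scenario.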
  2. The connection between enumerated type & sub-type is enforced at compile-time in Rust, but only at run-time in C++; in the C++ example, if the caller provides the wrong sub-type for the given IP_ADDR value, all I can do is catch that at run-time & throw.

This sort of thing is why I've become such a big fan of Rust over the past year or two. Rust allows me to express intent at compile-time that in C++ I can only express at run-time. Recently, however, I encountered a situation that made me re-think my blanket belief that Rust was superior to C++ in its handling of sum types.

The Rust approach is great, until you've got too many cases. I've recently been working with the llvm-ir crate. It represents LLVM non-terminator instructions, not unreasonably, as an enumeration… with fifty-four cases. Assume a match statement has a pair of braces for each variant, with, say, an average of three lines to handle each case; that's an average of five lines per variant, or 270 in total: not readable or maintainable in my opinion.

Note: I realize that modern C++ provides the std::variant type, which offers an explicit representation of a sum type. Since provisioning a std::variant with fifty-four types seems as unpleasant as the corresponding Rust match statement, I don't address that here.

A Solution in C++

In C++ one could implement a more elegant solution, albeit still at the cost of deferring error detection from compile-time to run-time.

Let's give ourselves a little enumeration (indicating the associated types in comments):

enum opcode {
  A /*(A_payload)*/,
  B /*(B_payload)*/,
  C /*(C_payload)*/
};

and let's give ourselves three trivial types to play the role of payloads:

#include <iostream>

struct A_payload {
  void do_a() const {
    std::cout << "You are manipulating an A." << std::endl;
  }
};
struct B_payload {
  void do_b() const {
    std::cout << "You are manipulating a B." << std::endl;
  }
};
struct C_payload {
  void do_c() const {
    std::cout << "You are manipulating a C." << std::endl;
  }
};

Since in this example the three types are unrelated (no common base class), we have to use type erasure to store a map of handlers indexed by enumeration value. But due to a little trick I picked up from Alexandrescu, we can still present a clean interface to our handlers:

#include <any>
#include <unordered_map>

struct dispatcher {
  // "register" a handler that takes OP & T here
  template <typename T, opcode OP, void (*callback)(opcode code, const T&)>
  void add() {
    struct loc
    {
      static void trampoline(opcode code, const std::any &p)
      {
        // A mismatch is caught at run-time here: any_cast throws std::bad_any_cast
        const T &t = std::any_cast<const T&>(p);
        return callback(code, t);
      }
    };
    table_[OP] = &loc::trampoline;
  }
  // invoke it here-- your type will implicitly be converted to an `any`
  void
  operator()(opcode code, const std::any &p)
  {
    // at() throws std::out_of_range for an unregistered opcode,
    // rather than calling through a null function pointer
    (*table_.at(code))(code, p);
  }
private:
  // this is my map from enum value to handler
  std::unordered_map<opcode, void (*)(opcode, const std::any&)> table_;
};

We store a map from opcode to pointer-to-function taking std::any. The trick is to use a member template to "wrap" the properly-typed callbacks with a function of std::any (that we can actually store in the lookup table) to spare the caller having to deal with recovering their types.

So if we give ourselves three trivial handlers:

void
handle_A(opcode /*code*/, const A_payload &x) {
  x.do_a();
}
void
handle_B(opcode /*code*/, const B_payload &x) {
  x.do_b();
}
void
handle_C(opcode /*code*/, const C_payload &x) {
  x.do_c();
}

we can set it all up like so:

// set up our dispatcher here...
dispatcher D;
// this is where we "register" the connection between enum
// values & payload type
D.add<A_payload, A, handle_A>();
D.add<B_payload, B, handle_B>();
D.add<C_payload, C, handle_C>();

So, where we have an instance of a sum type (E, X) where E is an opcode and X is the corresponding payload, we can invoke the correct handler like so:

D(E, X);

For instance:

// This will invoke the previously registered handlers:
D(A, A_payload()); // prints "You are manipulating an A."
D(B, B_payload()); // "You are manipulating a B."
D(C, C_payload()); // "You are manipulating a C."

// This will throw-- note that we're trying to handle 
// an enum value of `A` with the wrong payload. It will
// *compile*, but it will throw at run-time
D(A, C_payload());

To summarize, in C++ I may not be able to assert at compile-time the connection between enumeration value & payload type; I am reduced to checking it at run-time. However, that trade-off lets me write handlers in terms of the concrete types on which they operate, with no downcasting, destructuring, &c.; I escape needing a giant switch statement; and I can still dispatch in (expected) O(1) time, since the table is a hash map.

I can do that in Rust with a big match statement, and I suppose you could argue that my C++ implementation needs to register that association for every enum value at run-time, so what's the difference? Well, the registration can be done once in some static initializer, whereas the match statement is going to sprawl all over my implementation code.

The Best I Could Come Up With in Rust Without Macros

My line of attack in Rust was to use the Visitor pattern instead of a hand-built dispatcher. To continue the example above, let us suppose we have the following sum type:

struct A_payload {
    // -- snip --
}
struct B_payload {
    // -- snip --
}
struct C_payload {
    // -- snip --
}
enum Opcode {
    A(A_payload),
    B(B_payload),
    C(C_payload),
}

Let's set up our Visitor traits:

// assumption: Result<T> is a crate-level alias, for instance:
type Result<T> = std::result::Result<T, Box<dyn std::error::Error>>;

trait VisitableOpcode<T> {
    fn accept(&self, visitor: &mut dyn OpcodeVisitor<T>) -> Result<T>;
}

trait OpcodeVisitor<T> {
    fn visit_a(&mut self, instr: &A_payload) -> Result<T>;
    fn visit_b(&mut self, instr: &B_payload) -> Result<T>;
    fn visit_c(&mut self, instr: &C_payload) -> Result<T>;
}

and we'll implement VisitableOpcode for each payload type:

impl VisitableOpcode<()> for A_payload {
    fn accept(&self, visitor: &mut dyn OpcodeVisitor<()>) -> Result<()> {
        visitor.visit_a(self)
    }
}
impl VisitableOpcode<()> for B_payload {
    fn accept(&self, visitor: &mut dyn OpcodeVisitor<()>) -> Result<()> {
        visitor.visit_b(self)
    }
}
impl VisitableOpcode<()> for C_payload {
    fn accept(&self, visitor: &mut dyn OpcodeVisitor<()>) -> Result<()> {
        visitor.visit_c(self)
    }
}

So now, given an opcode, all we need to do is obtain a &dyn VisitableOpcode<()> for its payload:

fn visitable_opcode<'a>(
    instr: &'a Opcode,
) -> &'a (dyn VisitableOpcode<()> + 'a) {
    // each arm's &X_payload coerces to the trait object;
    // one arm per variant (with 54 variants, 54 arms)
    match instr {
        Opcode::A(a) => a,
        Opcode::B(b) => b,
        Opcode::C(c) => c,
    }
}

So now, where we want to process opcodes, instead of a massive match statement, we can say:

let v = // something that implements OpcodeVisitor...
let op = // some element of opcode
visitable_opcode(op).accept(&mut v);
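Pulling the pieces together, here is a minimal end-to-end sketch of the Visitor approach that compiles on its own. Assumptions: the payloads are empty unit structs, Result<T> is a local alias with a String error, the enum is spelled Opcode as in visitable_opcode, and Recorder and run are hypothetical names used only for illustration.

```rust
// Hypothetical stand-ins for the snipped payload bodies; the names keep
// the article's C++-style spelling, so the naming lint is silenced.
#[allow(non_camel_case_types)]
struct A_payload;
#[allow(non_camel_case_types)]
struct B_payload;
#[allow(non_camel_case_types)]
struct C_payload;

enum Opcode {
    A(A_payload),
    B(B_payload),
    C(C_payload),
}

// Assumed alias: the article's Result<T> is defined elsewhere.
type Result<T> = std::result::Result<T, String>;

trait VisitableOpcode<T> {
    fn accept(&self, visitor: &mut dyn OpcodeVisitor<T>) -> Result<T>;
}

trait OpcodeVisitor<T> {
    fn visit_a(&mut self, instr: &A_payload) -> Result<T>;
    fn visit_b(&mut self, instr: &B_payload) -> Result<T>;
    fn visit_c(&mut self, instr: &C_payload) -> Result<T>;
}

// Each payload forwards to "its" visitor method.
impl VisitableOpcode<()> for A_payload {
    fn accept(&self, visitor: &mut dyn OpcodeVisitor<()>) -> Result<()> {
        visitor.visit_a(self)
    }
}
impl VisitableOpcode<()> for B_payload {
    fn accept(&self, visitor: &mut dyn OpcodeVisitor<()>) -> Result<()> {
        visitor.visit_b(self)
    }
}
impl VisitableOpcode<()> for C_payload {
    fn accept(&self, visitor: &mut dyn OpcodeVisitor<()>) -> Result<()> {
        visitor.visit_c(self)
    }
}

// The one remaining match: each arm's &X_payload coerces to the trait object.
fn visitable_opcode(instr: &Opcode) -> &dyn VisitableOpcode<()> {
    match instr {
        Opcode::A(a) => a,
        Opcode::B(b) => b,
        Opcode::C(c) => c,
    }
}

// Hypothetical visitor that just records which handler ran.
struct Recorder {
    log: Vec<&'static str>,
}

impl OpcodeVisitor<()> for Recorder {
    fn visit_a(&mut self, _instr: &A_payload) -> Result<()> {
        self.log.push("a");
        Ok(())
    }
    fn visit_b(&mut self, _instr: &B_payload) -> Result<()> {
        self.log.push("b");
        Ok(())
    }
    fn visit_c(&mut self, _instr: &C_payload) -> Result<()> {
        self.log.push("c");
        Ok(())
    }
}

// Dispatch every opcode through its payload's accept, stopping at the first error.
fn run(ops: &[Opcode]) -> Result<Vec<&'static str>> {
    let mut v = Recorder { log: Vec::new() };
    for op in ops {
        visitable_opcode(op).accept(&mut v)?;
    }
    Ok(v.log)
}

fn main() {
    let ops = [Opcode::B(B_payload), Opcode::A(A_payload), Opcode::C(C_payload)];
    println!("{:?}", run(&ops)); // prints Ok(["b", "a", "c"])
}
```

The exhaustiveness check now lives in visitable_opcode: adding a D variant to Opcode without a matching arm fails to compile, and once a visit_d method is added to OpcodeVisitor, every visitor that doesn't implement it fails to compile too.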

To summarize, my Rust implementation asserts at compile-time the connection between enumeration variant & payload type, my handlers again need to perform no downcasting or destructuring, I escape the need for a giant match statement, and I can still dispatch in O(1) time.

That comes at the cost of a lot of boilerplate. The massive match statement remains, just in visitable_opcode. One could regard that as the Rust analog of the C++ static initialization function that registers all instances of the sum type with my dispatcher: a one-time setup that can be hustled off to initialization code rather than my main execution path. But still: I need to implement a (trivial) trait for each variant, and the OpcodeVisitor trait has, again, 54 members.

Conclusion

At this point, I'm torn as to how to approach this in Rust. The boilerplate involved with my Visitor solution, along with the hassle of trying to manage the lifetimes of the references being handed around, leaves me leaning toward just using the match statement. The last weapon I have is Rust's macro facility: if I could say something like:

#[derive(OpcodeVisitor<()>)]
struct MyVisitor {
    // -- snip --
}

and:

#[derive(VisitableOpcode<()>)]
enum Opcode {
    // -- snip --
}

to get all that boilerplate generated for me automatically, that might settle the matter for me. If I ever get around to building something like that, I'll make it the subject of a future post.
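For what it's worth, a declarative macro_rules! macro can already generate much of this boilerplate without writing a procedural-macro crate for a #[derive]. The sketch below is one possible shape under that assumption, not the derive described above and not anything from the llvm-ir crate; visitable_enum!, AddPayload, NegPayload and Evaluator are all hypothetical names. It emits the enum, a visitor trait with one method per variant, and an accept method holding the single exhaustive match.

```rust
// Hypothetical macro: one invocation lists each variant, its payload type,
// and the visitor method that handles it.
macro_rules! visitable_enum {
    ($name:ident, $visitor:ident {
        $($variant:ident($payload:ty) => $method:ident),* $(,)?
    }) => {
        // The sum type itself.
        enum $name {
            $($variant($payload)),*
        }

        // One required method per variant; no default bodies, so a visitor
        // that forgets one fails to compile.
        trait $visitor<T> {
            $(fn $method(&mut self, instr: &$payload) -> T;)*
        }

        impl $name {
            // The single exhaustive match, written once by the macro.
            fn accept<T>(&self, visitor: &mut dyn $visitor<T>) -> T {
                match self {
                    $(Self::$variant(p) => visitor.$method(p)),*
                }
            }
        }
    };
}

// Hypothetical payload types for a two-instruction toy language.
struct AddPayload {
    lhs: i64,
    rhs: i64,
}

struct NegPayload {
    val: i64,
}

visitable_enum!(Opcode, OpcodeVisitor {
    Add(AddPayload) => visit_add,
    Neg(NegPayload) => visit_neg,
});

// A visitor that evaluates each instruction to a number.
struct Evaluator;

impl OpcodeVisitor<i64> for Evaluator {
    fn visit_add(&mut self, instr: &AddPayload) -> i64 {
        instr.lhs + instr.rhs
    }
    fn visit_neg(&mut self, instr: &NegPayload) -> i64 {
        -instr.val
    }
}

fn eval(op: &Opcode) -> i64 {
    op.accept::<i64>(&mut Evaluator)
}

fn main() {
    let ops = [
        Opcode::Add(AddPayload { lhs: 2, rhs: 3 }),
        Opcode::Neg(NegPayload { val: 7 }),
    ];
    for op in &ops {
        println!("{}", eval(op)); // prints 5, then -7
    }
}
```

Because the generated trait methods have no default bodies, adding a line to the macro invocation forces every visitor to implement the new method (E0046 otherwise), so the compile-time guarantee survives while the per-variant trait impls and the hand-written match disappear.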

05/24/21 07:43