Okay so, just to make sure I'm understanding: the gist here is that in small-batch inference you don't mind gathering the experts for each token, whereas for training you care about routing tokens to experts without gathering the experts.
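Concretely, I'm picturing the BS=1 decode path as roughly the sketch below (my own pseudocode, not the post's actual code; the weight names `router_w`, `w_up`, `w_down` and the two-layer ReLU experts are just assumptions): for a single token you index the k active experts' weights directly and do a couple of tiny matmuls, with no token-to-expert permutation step at all.

```python
import torch

def moe_forward_bs1(x, router_w, w_up, w_down, k=2):
    # x: [d_model] (one token); router_w: [n_experts, d_model]
    # w_up: [n_experts, d_model, d_ff]; w_down: [n_experts, d_ff, d_model]
    logits = router_w @ x                        # [n_experts]
    weights, idx = torch.topk(logits, k)         # pick the k active experts
    weights = torch.softmax(weights, dim=-1)
    # Gather only those k experts' weights -- cheap when k << n_experts
    # and there is a single token to serve.
    up, down = w_up[idx], w_down[idx]            # [k, d_model, d_ff], [k, d_ff, d_model]
    h = torch.relu(torch.einsum('d,kdf->kf', x, up))
    out = torch.einsum('kf,kfd->kd', h, down)    # [k, d_model]
    return (weights[:, None] * out).sum(0)       # weighted mix of expert outputs
```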
Don't you get a small version of the same problem for large prefills at inference time though?
> Don't you get a small version of the same problem for large prefills at inference time though?
Yep! This code will not be fast for prefill either. One TODO we have is to use a different implementation for prefill. Luckily, for prefill the overhead is not a massive deal.
Thank you for the post. You mentioned that the code would not be good for batch sizes > 1. I was wondering: how bad, exactly? Could you share numbers for batch sizes typically used in inference settings (say, BS = {1, 4, 16})?
This approach is only really good for BS=1; it might be reasonable for BS=2 or 3 (although that depends on the sparsity).
For larger batch sizes you want different kernel strategies that we don't discuss in this post.
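As a rough back-of-envelope (my own illustrative numbers, not measurements from the post, assuming something like 256 experts with top-8 routing and uniform, independent routing per token): the per-token gather only wins while a batch touches a small fraction of the experts.

```python
def expected_active_experts(batch_size, n_experts=256, top_k=8):
    # Probability a given expert is untouched by one token is
    # (n_experts - top_k) / n_experts under uniform top-k routing;
    # raise it to the batch size for independent tokens.
    p_untouched = ((n_experts - top_k) / n_experts) ** batch_size
    return n_experts * (1 - p_untouched)

for bs in (1, 2, 4, 8, 16):
    active = expected_active_experts(bs)
    print(f"BS={bs:>2}: ~{active:5.1f}/256 experts active "
          f"(~{active / 256:.0%} of expert weights touched)")
```

Under those assumptions BS=1 touches 8 experts, but BS=16 already touches ~100 of the 256, so you're reading a large fraction of the expert weights every step and grouping tokens by expert (the training-style strategy) starts to win.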