Bags with Balls
KorigamiK
This post describes my first experience solving a "hard" problem on CodeForces.
When I logged in to CodeForces for the first time, I went to the Problemset section hoping to understand the ways of "competitive coding". I wasn't sure how the difficulty rating system worked, so I simply picked the first problem in the list and had a go.
The question
There are $n$ bags, each containing $m$ balls numbered from $1$ to $m$. For every $i \in [1, m]$, there is exactly one ball with number $i$ in each bag.
You have to take exactly one ball from each bag (all bags are different, so, for example, taking ball 1 from the first bag and ball 2 from the second bag is not the same as taking ball 2 from the first bag and ball 1 from the second bag). After that, you calculate the number of balls with odd numbers among the ones you have taken. Let the number of these balls be $F$. Your task is to calculate the sum of $F^k$ over all possible ways to take $n$ balls, one from each bag. Since the answer can be very large, it has to be printed modulo $998244353$.
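Before any clever maths, it helps to have a brute-force reference to check against. The sketch below is my own illustration, not part of the final solution, and the name `bruteForce` is hypothetical. It enumerates all $m^n$ ways to take one ball from each bag, like an odometer, and adds up $F^k$ exactly with 64-bit integers, so it only works for tiny $n$, $m$ and $k$.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical brute-force reference: enumerates every way to take one ball
// from each of n bags (balls numbered 1..m) and sums F^k, where F is the
// number of odd-numbered balls taken. Only feasible for tiny n, m, k.
std::uint64_t bruteForce(int n, int m, int k)
{
    std::vector<int> pick(n, 1); // pick[b] = number of the ball taken from bag b
    std::uint64_t total = 0;
    while (true)
    {
        int odd = 0;
        for (int ball : pick)
            odd += ball % 2; // count odd-numbered balls, i.e. F
        std::uint64_t term = 1;
        for (int e = 0; e < k; ++e)
            term *= odd; // F^k
        total += term;

        // Advance pick like an odometer whose digits run from 1 to m.
        int b = 0;
        while (b < n && pick[b] == m)
            pick[b++] = 1;
        if (b == n)
            break; // every combination has been visited
        ++pick[b];
    }
    return total;
}
```

For example, `bruteForce(2, 3, 8)` returns 1028: four of the nine possible pairs contain two odd balls ($4 \cdot 2^8 = 1024$) and four contain exactly one ($4 \cdot 1 = 4$).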
Initial Thoughts
As soon as I saw $F^k$, I thought about multi-index notation for $n$-tuples, where $F$ would be a tuple of all the selected odd-numbered balls.
But that was stupid: it's just the regular exponent. Still, it was pretty interesting of them to ask for the sum of the number of odd balls raised to the $k$-th power.
Spoilers
I was definitely stuck on this, and there weren't any solutions on the internet either. But I was able to find some similar problems, and that's where WolframAlpha gave me the key to solving this problem.
Formulating the solution
We can consider the number of ways to pick $i$ odd balls from $n$ bags, each containing $m$ balls. This can be found by:
$$
\binom{n}{i} \left\lfloor \frac{m+1}{2} \right\rfloor^{i} \left\lfloor \frac{m}{2} \right\rfloor^{n-i}
$$
We choose the $i$ bags that we want the odd balls from, and each of those $i$ bags has $\lfloor (m+1)/2 \rfloor$ odd-numbered balls to choose from; the remaining $n - i$ bags must give even balls, with $\lfloor m/2 \rfloor$ choices each.
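As a quick sanity check of this count (a small example of my own), take $n = 2$ bags with $m = 3$ balls. Each bag holds $\lfloor 4/2 \rfloor = 2$ odd balls (1 and 3) and $\lfloor 3/2 \rfloor = 1$ even ball (2), and the counts for $i = 0, 1, 2$ add up to all $3^2 = 9$ possible pairs:

$$
\binom{2}{0} 2^0 \, 1^2 + \binom{2}{1} 2^1 \, 1^1 + \binom{2}{2} 2^2 \, 1^0 = 1 + 4 + 4 = 9 = 3^2.
$$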
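To make the expressions easier to manage, we will write $x = \lfloor \frac{m+1}{2} \rfloor$ for the number of odd balls in a bag and $y = \lfloor \frac{m}{2} \rfloor$ for the number of even balls, so that $x + y = m$.
So the final answer can be formulated as
$$
\sum_{i=0}^{n} \binom{n}{i} x^i y^{n-i} \, i^k
$$
where we sum over every possible number $i$ of odd balls (from $0$ to $n$), and each way of taking exactly $i$ odd balls contributes $i^k$ to the answer.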
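Before simplifying, it is worth checking this formula numerically. The sketch below (the names `binomialSum` and `power` are my own, hypothetical helpers) evaluates the sum term by term modulo $998244353$, updating $\binom{n}{i}$ incrementally with a modular inverse from Fermat's little theorem. It takes $O(n \log \text{MOD})$ time, which is fine for checking small cases but is exactly the dependence on $n$ we want to get rid of.

```cpp
#include <cassert>
#include <cstdint>

constexpr std::int64_t MOD = 998244353;

// Binary exponentiation: b^p mod MOD (with 0^0 = 1).
std::int64_t power(std::int64_t b, std::int64_t p)
{
    std::int64_t res = 1;
    b %= MOD;
    for (; p > 0; p >>= 1, b = b * b % MOD)
        if (p & 1)
            res = res * b % MOD;
    return res;
}

// Hypothetical helper: evaluates sum_{i=0}^{n} C(n, i) x^i y^(n-i) i^k mod MOD
// directly, with x = ceil(m/2) odd balls and y = floor(m/2) even balls per bag.
std::int64_t binomialSum(std::int64_t n, std::int64_t m, std::int64_t k)
{
    const std::int64_t x = (m + 1) / 2, y = m / 2;
    std::int64_t binom = 1; // C(n, i), starting at i = 0
    std::int64_t sum = 0;
    for (std::int64_t i = 0; i <= n; ++i)
    {
        // Number of ways to take exactly i odd balls.
        std::int64_t ways = binom * power(x, i) % MOD * power(y, n - i) % MOD;
        sum = (sum + ways * power(i, k)) % MOD;
        // C(n, i + 1) = C(n, i) * (n - i) / (i + 1), dividing via a modular inverse.
        binom = binom * ((n - i) % MOD) % MOD * power(i + 1, MOD - 2) % MOD;
    }
    return sum;
}
```

For the $n = 2$, $m = 3$ case, `binomialSum(2, 3, 8)` gives $4 \cdot 1^8 + 4 \cdot 2^8 = 1028$.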
Now the challenge was to simplify this expression. As it stands, the sum has $n + 1$ terms, so evaluating it directly costs at least $O(n)$ per test case, and $n$ is allowed to be far too large for that in this "hard" problem.
It looks awfully close to the well-known binomial theorem, but the pesky $i^k$ factor makes it almost impossible to simplify further.
Luckily, after some research, I stumbled upon an article on WolframAlpha about Stirling numbers.
The Stirling numbers of the second kind are defined as:
the number of ways to partition a set of $a$ labelled objects into $b$ nonempty unlabelled subsets.
They are denoted by ${a \brace b}$, with the two numbers stacked inside curly braces just like a binomial coefficient.
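To make the definition concrete, here is a small sketch of my own (the names `countPartitions` and `countPartitionsFrom` are hypothetical) that counts partitions by brute force rather than by any formula. Each partition of $a$ labelled objects corresponds to exactly one "restricted growth string": object 0 goes into block 0, and every later object goes either into an already opened block or into a new block numbered one higher than the largest so far.

```cpp
#include <cassert>

// Recursively assigns objects to blocks so that each set partition is
// generated exactly once (restricted growth strings).
// 'next' is the index of the object being placed, 'used' the number of
// blocks opened so far, a the number of objects, b the wanted block count.
long long countPartitionsFrom(int next, int used, int a, int b)
{
    if (used > b)
        return 0; // already more blocks than allowed
    if (next == a)
        return used == b ? 1 : 0; // all objects placed: keep exactly b nonempty blocks
    long long total = 0;
    for (int block = 0; block <= used; ++block) // an existing block, or open block 'used'
        total += countPartitionsFrom(next + 1, block == used ? used + 1 : used, a, b);
    return total;
}

// Hypothetical helper: number of ways to partition a labelled objects into
// b nonempty unlabelled subsets, i.e. the Stirling number {a brace b}.
long long countPartitions(int a, int b)
{
    return countPartitionsFrom(0, 0, a, b);
}
```

For example, the three partitions counted by `countPartitions(3, 2)` are $\{1,2\}\{3\}$, $\{1,3\}\{2\}$ and $\{2,3\}\{1\}$.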
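They also have an interesting property:
$$
i^k = \sum_{j=0}^{k} {k \brace j} \, i^{\underline{j}}, \qquad i^{\underline{j}} = i(i-1)\cdots(i-j+1),
$$
where $i^{\underline{j}}$ is the falling factorial. Both sides count the ways to give each of $k$ labelled items one of $i$ values: group the assignments by which items share a value (a partition into $j$ blocks), then give the $j$ blocks distinct values in $i^{\underline{j}}$ ways. For example, $3^2 = {2 \brace 1} \cdot 3 + {2 \brace 2} \cdot 3 \cdot 2 = 3 + 6 = 9$. This identity will help us deal with the exponent.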
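I present to you the steps to the simplification:
$$
\begin{aligned}
\sum_{i=0}^{n} \binom{n}{i} x^i y^{n-i} \, i^k
&= \sum_{i=0}^{n} \binom{n}{i} x^i y^{n-i} \sum_{j=0}^{k} {k \brace j} \, j! \binom{i}{j} \\
&= \sum_{j=0}^{k} {k \brace j} \, j! \sum_{i=j}^{n} \binom{n}{i} \binom{i}{j} x^i y^{n-i} \\
&= \sum_{j=0}^{k} {k \brace j} \, j! \binom{n}{j} \sum_{i=j}^{n} \binom{n-j}{i-j} x^i y^{n-i} \\
&= \sum_{j=0}^{k} {k \brace j} \, n^{\underline{j}} x^j \sum_{t=0}^{n-j} \binom{n-j}{t} x^t y^{n-j-t} \\
&= \sum_{j=0}^{k} {k \brace j} \, n^{\underline{j}} x^j (x+y)^{n-j}
\end{aligned}
$$
The first step uses $i^{\underline{j}} = j! \binom{i}{j}$, the third uses $\binom{n}{i}\binom{i}{j} = \binom{n}{j}\binom{n-j}{i-j}$, and the last one is the binomial theorem with $t = i - j$. Since $x + y = m$ (every ball is either odd or even), $n^{\underline{j}} = 0$ for $j > n$, and ${k \brace 0} = 0$ for $k \ge 1$, the answer can be simplified into
$$
\sum_{j=1}^{\min(k, n)} {k \brace j} \, n^{\underline{j}} \, x^j \, m^{n-j}
$$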
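The Stirling numbers ${k \brace j}$ can be precomputed into a 2-D DP array with the recurrence ${a \brace b} = b {a-1 \brace b} + {a-1 \brace b-1}$. The final complexity of the program should then be $O(K^2 + t \cdot k \log n)$, where $K$ is the size of the precomputed table and $t$ is the number of test cases: the table is filled once, and each test case does at most $\min(k, n)$ iterations with two fast exponentiations each.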
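As a sanity check of the algebra, the sketch below evaluates the simplified formula with exact 64-bit integers (no modulus) for tiny inputs, building the Stirling numbers with the standard recurrence ${a \brace b} = b {a-1 \brace b} + {a-1 \brace b-1}$. The name `closedForm` is my own, and the code overflows for anything but small $n$, $m$ and $k$.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical helper: evaluates the simplified formula
//   sum_{j=1}^{min(k,n)} S(k, j) * n^(falling j) * x^j * m^(n-j)
// with exact 64-bit integers, so it is only meant for small n, m, k.
std::int64_t closedForm(int n, int m, int k)
{
    // S[a][b]: Stirling numbers of the second kind, S(a,b) = b*S(a-1,b) + S(a-1,b-1).
    std::vector<std::vector<std::int64_t>> S(k + 1, std::vector<std::int64_t>(k + 1, 0));
    S[0][0] = 1;
    for (int a = 1; a <= k; ++a)
        for (int b = 1; b <= a; ++b)
            S[a][b] = b * S[a - 1][b] + S[a - 1][b - 1];

    const std::int64_t x = (m + 1) / 2; // odd-numbered balls per bag
    std::int64_t sum = 0;
    std::int64_t falling = 1; // n * (n - 1) * ... * (n - j + 1)
    std::int64_t xPow = 1;    // x^j
    for (int j = 1; j <= k && j <= n; ++j)
    {
        falling *= n - j + 1;
        xPow *= x;
        std::int64_t mPow = 1; // m^(n - j)
        for (int e = 0; e < n - j; ++e)
            mPow *= m;
        sum += S[k][j] * falling * xPow * mPow;
    }
    return sum;
}
```

For instance, `closedForm(2, 3, 8)` evaluates to $12 + 127 \cdot 2 \cdot 4 = 1028$, which matches direct enumeration of the nine possible pairs.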
Code
This is the final code that I wrote using this approach
#include <iostream>
#define print(x) std::cout << x << "\n"

constexpr int MOD = 998244353;
using i64 = long long;

// Modular multiplication; the 64-bit intermediate avoids overflow.
int mul(int x, int y)
{
    return (i64)x * y % MOD;
}

// Adds two residues and brings the result back into [0, MOD).
int add(int x, int y)
{
    x += y;
    if (x >= MOD)
        x -= MOD;
    if (x < 0)
        x += MOD;
    return x;
}

// Binary exponentiation: b^p mod MOD in O(log p).
int raiseToPower(int b, int p)
{
    int res = 1;
    for (; p; p >>= 1, b = (i64)b * b % MOD)
        if (p & 1)
            res = (i64)res * b % MOD;
    return res;
}

constexpr int MAX_K = 2006;
int STIRLING[MAX_K][MAX_K]; // STIRLING[n][m] = S(n, m), Stirling numbers of the second kind

int getStirling(int n, int m)
{
    if (n < m || m < 0)
        return 0;
    return STIRLING[n][m];
}

void preCalcStirlings()
{
    STIRLING[1][1] = 1;
    for (int n = 2; n < MAX_K; ++n)
    {
        for (int m = 1; m < n; ++m)
            // S(n, m) = m * S(n - 1, m) + S(n - 1, m - 1)
            STIRLING[n][m] = add(mul(m, STIRLING[n - 1][m]), STIRLING[n - 1][m - 1]);
        STIRLING[n][n] = 1;
    }
}

int solve()
{
    // Thank you @LacLic for the approach
    int n, m, k;
    std::cin >> n >> m >> k;
    int sum = 0;
    int descendingFact = n; // falling factorial n * (n - 1) * ... * (n - j + 1)
    int x = (m + 1) / 2;    // odd-numbered balls per bag; the even count only enters via x + y = m
    // Sum over j of S(k, j) * x^j * m^(n - j) * (falling factorial of n of length j)
    for (int j = 1; j <= k && j <= n; ++j)
    {
        sum = add(sum, mul(getStirling(k, j), mul(raiseToPower(x, j), mul(raiseToPower(m, n - j), descendingFact))));
        descendingFact = mul(descendingFact, n - j);
    }
    return sum;
}

int main(int argc, char const *argv[])
{
    preCalcStirlings();
    int t;
    std::cin >> t;
    while (t--)
        print(solve());
    return 0;
}
Conclusion
I hope you found this thread interesting. I don't think I will be able to give much time to the competitive coding side of things. I'd be embarrassed to say how long it took me to solve this problem, but isn't it always about the journey? For me, it's similar to a Sudoku or a puzzle that I would solve once in a while, without spending all my brain power on just doing them.