438B


  Posted on 09 Dec 2014
disjoint sets

Description


Here is an undirected graph with $n$ nodes and $m$ edges ($n \le 10^5$). Each node $i$ has a weight $a_i$. $p(u, v)$ denotes the simple route from $u$ to $v$ with the largest $f(u, v)$, where $f(u, v)$ is the weight of the node with the least weight on that simple route. Output the average of $f(u, v)$ over all possible pairs of nodes $(u, v)$ on the graph.

Tutorial


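Sort the nodes by weight in decreasing order. Starting from an empty graph, add the nodes to the graph one by one. Each time a node is added, some components may become connected through it. For every pair $(u, v)$ where $u$ and $v$ lie in two different components that become connected by the new node, $f(u, v)$ equals the weight of the new node, because the new node is the lightest node added so far. Use disjoint sets to work it out: join the components to the new node one by one, and during each join add the sum of the $f(u, v)$s of the newly connected pairs to the answer, i.e. the new node's weight times the product of the sizes of the two components being joined.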

Solution


#include <vector>
#include <cstdio>
#include <algorithm>
using namespace std;

// Capacity of every fixed-size array below: 1e5 plus a little slack.
const int MAX_N = int(1e5) + 10;

// n: number of nodes, m: number of edges; both are read in main().
int n, m;
// Adjacency lists: edge[u] holds the 0-based neighbours of node u.
vector<int> edge[MAX_N];
// (weight, node index) pairs; main() sorts them so nodes can be visited by weight.
pair<int, int> animal[MAX_N];
// vis[u] is set once node u has been added to the growing graph in main().
bool vis[MAX_N];

/**
 * Disjoint-set (union-find) forest over the elements 0 .. n-1 that also keeps
 * track of the size of every set.
 *
 * father[i] is the parent of element i; a root is its own parent.
 * num[r] is the number of elements in the set rooted at r and is only
 * meaningful when r is a root.
 *
 * Storage is sized from the constructor argument, so the object is small
 * enough to live on the stack and is not tied to a fixed compile-time capacity.
 */
struct Disjoint_sets
{
	std::vector<int> father;
	std::vector<int> num;

	// Creates a structure with no elements.
	Disjoint_sets()
	{}

	// Creates n singleton sets {0}, {1}, ..., {n-1}; a non-positive n yields none.
	Disjoint_sets(int n)
	{
		const std::size_t count = n > 0 ? static_cast<std::size_t>(n) : 0;
		father.resize(count);
		num.assign(count, 1);
		for (std::size_t i = 0; i < count; i++)
			father[i] = static_cast<int>(i);
	}

	// Returns the representative of the set containing a, compressing the
	// path from a to that representative along the way.
	int root(int a)
	{
		// First pass: walk up to the root.
		int ret = a;
		while (father[ret] != ret)
			ret = father[ret];
		// Second pass: point every node on the walked path directly at the root.
		while (father[a] != a)
		{
			int b = a;
			a = father[a];
			father[b] = ret;
		}
		return ret;
	}

	// Merges the set containing a into the set containing b; the root of b's
	// set stays the root of the merged set. Does nothing when a and b already
	// share a set, so the stored size is never counted twice.
	void join(int a, int b)
	{
		const int root_a = root(a);
		const int root_b = root(b);
		if (root_a == root_b)
			return;
		father[root_a] = root_b;
		num[root_b] += num[root_a];
	}
};


int main()
{
	//input
	scanf("%d%d", &n, &m);
	for (int i = 0; i < n; i++)
	{
		int a;
		scanf("%d", &a);
		animal[i] = make_pair(a, i);
	}
	for (int i = 0; i < m; i++)
	{
		int a, b;
		scanf("%d%d", &a, &b);
		a--;
		b--;
		edge[a].push_back(b);
		edge[b].push_back(a);
	}

	//work
	Disjoint_sets d_sets(n);
	fill(vis, vis + n, 0);
	sort(animal, animal + n);
	long long ans = 0;
	for (int i = n - 1; i >= 0; i--)
	{
		int u = animal[i].second;
		int min_num = animal[i].first;
		vis[u] = true;
		for (int j = 0; j < (int)edge[u].size(); j++)
		{
			int v = edge[u][j];
			if (!vis[v])
				continue;
			if (d_sets.root(v) != d_sets.root(u))
			{
				ans += 1LL * min_num * d_sets.num[d_sets.root(v)] * d_sets.num[d_sets.root(u)];
				d_sets.join(v, u);
			}
		}
	}

	double final_ans = ans * 2.0 / n / (n - 1);
	printf("%.12f\n", final_ans);
	return 0;
}