feat: transparent client migration on DC connection death

When Telegram closes a DC TCP connection, instead of dropping all
multiplexed clients, the proxy now remaps them onto a surviving (or
freshly spawned) DC connection.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Sergey Prokhorov 2026-04-08 00:46:40 +02:00
parent ed2fa40ee5
commit 492956a598
7 changed files with 341 additions and 13 deletions

doc/migration-flow.md (new file)

@@ -0,0 +1,52 @@
# Transparent client migration on DC connection death
Telegram periodically closes the TCP connection to the proxy ("DC connection
rotation", typically every 3070 s). Instead of dropping all clients
multiplexed on that connection, the proxy remaps each idle client to a
surviving (or freshly-started) DC connection transparently.

**Key actors:**
- `mtp_down_conn (old)` — the dying downstream connection process
- `mtp_dc_pool` — pool managing all downstream connections for one DC
- `mtp_handler` — one process per connected Telegram client
- `mtp_down_conn (new)` — replacement downstream spawned by the pool
```mermaid
sequenceDiagram
participant TG as Telegram
participant OldDown as mtp_down_conn (old)
participant Pool as mtp_dc_pool
participant Handler as mtp_handler
participant NewDown as mtp_down_conn (new)
TG->>OldDown: TCP close
OldDown->>Pool: downstream_closing(self()) [sync]
Pool-->>Pool: remove OldDown from ds_store + monitors
Pool-->>NewDown: spawn & connect (maybe_restart_connection)
Pool-->>OldDown: ok
OldDown->>Handler: migrate(OldDown) [cast, to all known upstreams]
Note over OldDown: drain_mailbox(5000)
alt upstream_new in mailbox
Note over Pool,OldDown: Race: pool processed a {get} call just before<br/>downstream_closing — upstream_new cast already queued
Pool-->>OldDown: upstream_new(Handler2, Opts) [cast, queued]
OldDown->>Handler2: migrate(OldDown) [cast, immediately]
end
alt Handler was blocked in down_send
Handler-->>OldDown: {send, Data} [call, in mailbox]
OldDown-->>Handler: {error, migrating}
Note over Handler: metric[mid_send] → stop<br/>(client reconnects and resends)
else Handler was idle
Handler->>Pool: migrate(OldDown, self(), Opts) [sync]
Pool-->>Pool: remove Handler from upstreams map
Pool->>NewDown: upstream_new(Handler, Opts) [cast]
Pool-->>Handler: NewDown pid
Note over Handler: down = NewDown<br/>metric[ok]
end
Note over OldDown: stop {shutdown, downstream_migrated}
```
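**Condensed code sketch.** For quick orientation, here is the dying downstream's
side of this flow, boiled down from the `mtp_down_conn` changes in this commit
(guards, logging and metric updates are trimmed; an illustration, not the exact code):
```erlang
%% mtp_down_conn (old), simplified: the Telegram-side socket closed while
%% client handlers are still attached.
handle_info({tcp_closed, Sock}, #state{sock = Sock, upstreams = Ups, pool = Pool} = State)
  when map_size(Ups) > 0 ->
    %% 1. Leave the pool synchronously, so no new upstreams are routed here.
    ok = mtp_dc_pool:downstream_closing(Pool, self()),
    %% 2. Ask every attached handler (asynchronously) to migrate away.
    [mtp_handler:migrate(Upstream, self()) || Upstream <- maps:keys(Ups)],
    %% 3. Drain stragglers: unblock callers stuck in send/2 with {error, migrating},
    %%    forward late upstream_new assignments, then exit cleanly.
    _ = drain_mailbox(5000),
    {stop, {shutdown, downstream_migrated}, State}.
```
Each handler reacts to the `{migrate, OldDown}` cast by calling
`mtp_dc_pool:migrate(Pool, OldDown, self(), Opts)`: on success it rebinds its `down`
field to the returned pid; on `{error, empty}` or `{error, not_found}` it stops, so the
client reconnects and resends. Every outcome is counted in the `downstream_migration`
metric (`ok`, `empty`, `not_found`, or `mid_send` for a packet caught in flight).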

mtp_dc_pool.erl

@@ -17,6 +17,8 @@
-export([start_link/1,
get/3,
return/2,
downstream_closing/2,
migrate/4,
add_connection/1,
ack_connected/2,
status/1,
@@ -80,6 +82,23 @@ get(Pool, Upstream, #{addr := _} = Opts) ->
return(Pool, Upstream) ->
gen_server:cast(Pool, {return, Upstream}).
%% Called by a downstream that received tcp_closed with active upstreams.
%% Removes the downstream from the pool store synchronously so it won't receive
%% new upstreams while its handlers are migrating.
-spec downstream_closing(pid(), downstream()) -> ok.
downstream_closing(Pool, Downstream) ->
gen_server:call(Pool, {downstream_closing, Downstream}).
%% Atomically return an upstream from a dying downstream and assign it to a new one.
%% Avoids the "attempt to release unknown connection" warning that return+get would cause.
-spec migrate(pid(), downstream(), upstream(),
#{addr := mtp_config:netloc_v4v6(),
ad_tag => binary(),
packet_layer => mtp_down_conn:packet_layer()}) ->
downstream() | {error, empty | not_found}.
migrate(Pool, OldDown, Upstream, Opts) ->
gen_server:call(Pool, {migrate, OldDown, Upstream, Opts}).
add_connection(Pool) ->
gen_server:call(Pool, add_connection, 10000).
@@ -108,6 +127,11 @@ handle_call({get, Upstream, Opts}, _From, State) ->
{Downstream, State1} ->
{reply, Downstream, State1}
end;
handle_call({downstream_closing, Downstream}, _From, State) ->
{reply, ok, handle_downstream_closing(Downstream, State)};
handle_call({migrate, OldDown, Upstream, Opts}, _From, State) ->
{Reply, State1} = handle_migrate(OldDown, Upstream, Opts, State),
{reply, Reply, State1};
handle_call(add_connection, _From, State) ->
State1 = connect(State),
{reply, ok, State1};
@@ -178,6 +202,39 @@ handle_return(Upstream, #state{downstreams = Ds,
St#state{downstreams = Ds1,
upstreams = Us1}.
%% Remove a dying downstream from the store before its handlers migrate.
%% Called synchronously by mtp_down_conn so removal is complete before
%% {migrate, Self} is sent to upstreams.
handle_downstream_closing(Downstream, #state{downstreams = Ds,
downstream_monitors = DsM,
pending_downstreams = Pending} = St) ->
DsM1 = maps:filter(
fun(MonRef, Pid) when Pid =:= Downstream ->
erlang:demonitor(MonRef, [flush]),
false;
(_, _) ->
true
end, DsM),
Ds1 = ds_remove(Downstream, Ds),
Pending1 = lists:delete(Downstream, Pending),
maybe_restart_connection(
St#state{downstreams = Ds1,
downstream_monitors = DsM1,
pending_downstreams = Pending1}).
%% Atomically reassign an upstream from its dying downstream to a new one.
handle_migrate(_OldDown, Upstream, Opts, #state{upstreams = Us} = St) ->
case maps:take(Upstream, Us) of
{{_AnyOldDown, OldMonRef}, Us1} ->
erlang:demonitor(OldMonRef, [flush]),
case handle_get(Upstream, Opts, St#state{upstreams = Us1}) of
{empty, St1} -> {{error, empty}, St1};
{NewDown, St1} -> {NewDown, St1}
end;
error ->
{{error, not_found}, St}
end.
handle_down(MonRef, Pid, Reason, #state{downstreams = Ds,
downstream_monitors = DsM,
upstreams = Us,
@@ -196,6 +253,8 @@ handle_down(MonRef, Pid, Reason, #state{downstreams = Ds,
case Reason of
{shutdown, downstream_socket_closed} ->
?LOG_INFO("Downstream=~p closed (no active clients)", [Pid]);
{shutdown, downstream_migrated} ->
?LOG_INFO("Downstream=~p finished migrating clients", [Pid]);
_ ->
?LOG_ERROR("Downstream=~p is down. reason=~p", [Pid, Reason])
end,
@@ -204,7 +263,7 @@ handle_down(MonRef, Pid, Reason, #state{downstreams = Ds,
downstreams = Ds1,
downstream_monitors = DsM1});
_ ->
?LOG_ERROR("Unexpected DOWN. ref=~p, pid=~p, reason=~p", [MonRef, Pid, Reason]),
?LOG_WARNING("Unexpected DOWN. ref=~p, pid=~p, reason=~p", [MonRef, Pid, Reason]),
St
end
end.

mtp_down_conn.erl

@@ -101,7 +101,7 @@ shutdown(Conn) ->
gen_server:cast(Conn, shutdown).
%% To be called by upstream
-spec send(handle(), iodata()) -> ok | {error, unknown_upstream}.
-spec send(handle(), iodata()) -> ok | {error, unknown_upstream | migrating}.
send(Conn, Data) ->
gen_server:call(Conn, {send, Data}, ?SEND_TIMEOUT * 2).
@@ -166,10 +166,24 @@ handle_info({tcp, Sock, Data}, #state{sock = Sock, dc_id = DcId} = S) ->
{ok, S1} = handle_downstream_data(Data, S),
activate_if_no_overflow(S1),
{noreply, S1};
handle_info({tcp_closed, Sock}, #state{sock = Sock, upstreams = Ups} = State) ->
handle_info({tcp_closed, Sock}, #state{sock = Sock, upstreams = Ups, pool = Pool} = State) ->
case map_size(Ups) of
0 -> {stop, {shutdown, downstream_socket_closed}, State};
_ -> {stop, downstream_socket_closed, State}
0 ->
{stop, {shutdown, downstream_socket_closed}, State};
N ->
%% Remove self from pool first so no new upstreams can be assigned.
ok = mtp_dc_pool:downstream_closing(Pool, self()),
?LOG_INFO("Downstream socket closed with ~p active client(s); migrating", [N]),
%% Notify all known upstreams to migrate immediately.
[mtp_handler:migrate(Upstream, self()) || Upstream <- maps:keys(Ups)],
%% Drain remaining mailbox messages:
%% - {send,...} calls: reply {error, migrating} to unblock callers
%% - upstream_new casts: handlers assigned to us by the pool just
%% before downstream_closing ran; migrate them immediately
NDrained = drain_mailbox(5000),
NDrained > 0 andalso
?LOG_INFO("Drained ~p pending send call(s) during migration", [NDrained]),
{stop, {shutdown, downstream_migrated}, State}
end;
handle_info({tcp_error, Sock, Reason}, #state{sock = Sock} = State) ->
{stop, {downstream_tcp_error, Reason}, State};
@@ -194,6 +208,9 @@ handle_info(handshake_timeout, #state{stage = Stage, dc_id = DcId} = St) ->
end.
terminate({shutdown, downstream_migrated}, _State) ->
%% Normal shutdown during migration; no need to log or notify upstreams.
ok;
terminate(Reason, #state{upstreams = Ups}) ->
NUps = map_size(Ups),
case Reason of
@@ -207,8 +224,8 @@ terminate(Reason, #state{upstreams = Ups}) ->
lists:foreach(
fun(Upstream) ->
ok = mtp_handler:send(Upstream, {close_ext, Self})
end, maps:keys(Ups)),
ok.
end, maps:keys(Ups)).
code_change(_OldVsn, State, _Extra) ->
{ok, State}.
@@ -471,6 +488,26 @@ non_ack_cleanup_upstream(Upstream, #state{non_ack_count = Cnt,
St#state{non_ack_count = Cnt - UpsCnt,
non_ack_bytes = Oct - UpsOct}).
%% Drain pending messages from our mailbox during migration.
%% - gen_server:call({send,_}): reply {error, migrating} so callers unblock
%% - gen_server:cast({upstream_new,...}): send {migrate} immediately
%% Timeout controls how long to wait for the next message before giving up.
%% Returns count of drained send calls.
drain_mailbox(Timeout) ->
drain_mailbox(Timeout, 0).
drain_mailbox(Timeout, NSend) ->
receive
{'$gen_call', From, {send, _Data}} ->
gen_server:reply(From, {error, migrating}),
drain_mailbox(Timeout, NSend + 1);
{'$gen_cast', {upstream_new, Upstream, _Opts}} ->
mtp_handler:migrate(Upstream, self()),
drain_mailbox(Timeout, NSend)
after Timeout ->
NSend
end.
%%
%% Connect / handshake

mtp_handler.erl

@@ -10,7 +10,7 @@
-behaviour(ranch_protocol).
%% API
-export([start_link/3, start_link/4, send/2]).
-export([start_link/3, start_link/4, send/2, migrate/2]).
-export([hex/1, unhex/1]).
-export([keys_str/0]).
@@ -80,6 +80,10 @@ keys_str() ->
send(Upstream, Packet) ->
gen_server:cast(Upstream, Packet).
-spec migrate(pid(), OldDown :: mtp_down_conn:handle()) -> ok.
migrate(Upstream, OldDown) ->
gen_server:cast(Upstream, {migrate, OldDown}).
%% Callbacks
%% Custom gen_server init
@@ -175,6 +179,27 @@ handle_cast({close_ext, Down}, #state{down = Down, sock = USock, transport = UTr
?LOG_DEBUG("asked to close connection by downstream"),
ok = UTrans:close(USock),
{stop, normal, S#state{down = undefined}};
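%% Our downstream announced it is dying: ask the pool to atomically hand us a
%% replacement DC connection; if none is available, stop so the client reconnects.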
handle_cast({migrate, OldDown}, #state{down = OldDown, dc_id = {_DcId, Pool},
codec = Codec, addr = Addr,
ad_tag = AdTag, listener = Listener} = S) ->
{PacketLayerMod, _} = mtp_codec:info(packet, Codec),
Opts = #{addr => Addr, ad_tag => AdTag, packet_layer => PacketLayerMod},
case mtp_dc_pool:migrate(Pool, OldDown, self(), Opts) of
{error, Reason} ->
?LOG_DEBUG("Migration failed (~p), closing client", [Reason]),
true = is_atom(Reason),
mtp_metric:count_inc([?APP, downstream_migration, total], 1,
#{labels => [Listener, Reason]}),
{stop, normal, S#state{down = undefined}};
NewDown ->
?LOG_DEBUG("Migrated from ~p to ~p", [OldDown, NewDown]),
mtp_metric:count_inc([?APP, downstream_migration, total], 1,
#{labels => [Listener, ok]}),
{noreply, S#state{down = NewDown}}
end;
handle_cast({migrate, _StaleDown}, S) ->
%% Stale migrate cast from a down_conn we have already migrated away from; ignore.
{noreply, S};
handle_cast({simple_ack, Down, Confirm}, #state{down = Down} = S) ->
?LOG_INFO("Simple ack: ~p, ~p", [Down, Confirm]),
{noreply, S};
@@ -613,13 +638,19 @@ up_send_raw(Data, #state{sock = Sock,
end
end, #{labels => [Listener]}).
down_send(Packet, #state{down = Down} = S) ->
down_send(Packet, #state{down = Down, listener = Listener} = S) ->
%% ?LOG_DEBUG(">Down: ~p", [Packet]),
case mtp_down_conn:send(Down, Packet) of
ok ->
{ok, S};
{error, unknown_upstream} ->
handle_unknown_upstream(S)
handle_unknown_upstream(S);
{error, migrating} ->
%% DC connection is closing; this packet was never sent to TG.
%% Stop the handler so the client reconnects and resends.
mtp_metric:count_inc([?APP, downstream_migration, total], 1,
#{labels => [Listener, mid_send]}),
throw({stop, normal, S})
end.
handle_unknown_upstream(#state{down = Down, sock = USock, transport = UTrans} = S) ->

mtp_test_middle_server.erl

@@ -5,7 +5,8 @@
-export([start/2,
stop/1,
get_rpc_handler_state/1]).
get_rpc_handler_state/1,
close_connection/1]).
-export([start_link/3,
ranch_init/1]).
-export([init/1,
@@ -61,6 +62,10 @@ stop(Id) ->
get_rpc_handler_state(Pid) ->
gen_statem:call(Pid, get_rpc_handler_state).
%% Close the server-side TCP socket, simulating Telegram rotating the connection.
close_connection(Pid) ->
gen_statem:call(Pid, close_connection).
%% Callbacks
start_link(Ref, Transport, Opts) ->
@@ -159,6 +164,9 @@ on_tunnel(info, {tcp, _Sock, TcpData}, #t_state{codec = Codec0} = S) ->
{keep_state, activate(S2#t_state{codec = Codec1})};
on_tunnel({call, From}, get_rpc_handler_state, #t_state{rpc_handler_state = HSt}) ->
{keep_state_and_data, [{reply, From, HSt}]};
on_tunnel({call, From}, close_connection, #t_state{sock = Sock, transport = Transport}) ->
Transport:close(Sock),
{stop_and_reply, normal, [{reply, From, ok}]};
on_tunnel(Type, Event, S) ->
handle_event(Type, Event, ?FUNCTION_NAME, S).

mtp_test_reporter_rpc.erl (new file)

@@ -0,0 +1,17 @@
%% @doc rpc_handler for mtp_test_middle_server that echoes packets and reports
%% each one to a registered process named `mtp_test_rpc_sink'.
%% The report message is `{rpc_from, self(), ConnId, Data}', where `self()' is
%% the mtp_test_middle_server Ranch connection pid. This is useful for tests that need
%% to identify which DC connection a client is multiplexed on and close it.
-module(mtp_test_reporter_rpc).
-export([init/1,
handle_rpc/2]).
init(_) ->
#{}.
handle_rpc({data, ConnId, Data}, St) ->
mtp_test_rpc_sink ! {rpc_from, self(), ConnId, Data},
{rpc, {proxy_ans, ConnId, Data}, St};
handle_rpc({remote_closed, ConnId}, St) ->
{noreply, St#{ConnId => closed}}.

mtproto_proxy_SUITE.erl

@@ -27,7 +27,10 @@
domain_fronting_replay_case/1,
per_sni_secrets_on_case/1,
per_sni_secrets_wrong_secret_case/1,
malformed_tls_hello_decode_error_case/1
malformed_tls_hello_decode_error_case/1,
downstream_migration_case/1,
downstream_migration_multi_case/1,
downstream_migration_empty_pool_case/1
]).
-export([set_env/2,
@@ -781,6 +784,103 @@ malformed_tls_hello_decode_error_case(Cfg) when is_list(Cfg) ->
1, mtp_test_metric:get_tags(
count, [?APP, protocol_error, total], [?FUNCTION_NAME, tls_bad_client_hello])).
%% @doc Client survives a DC connection rotation (1 client, 2 DC connections available).
downstream_migration_case({pre, Cfg}) ->
setup_single(?FUNCTION_NAME, 10000 + ?LINE,
#{init_dc_connections => 2, rpc_handler => mtp_test_reporter_rpc}, Cfg);
downstream_migration_case({post, Cfg}) ->
stop_single(Cfg);
downstream_migration_case(Cfg) when is_list(Cfg) ->
DcId = ?config(dc_id, Cfg),
Host = ?config(mtp_host, Cfg),
Port = ?config(mtp_port, Cfg),
Secret = ?config(mtp_secret, Cfg),
Pool = mtp_dc_pool:dc_to_pool_name(DcId),
register(mtp_test_rpc_sink, self()),
try
Cli = mtp_test_client:connect(Host, Port, Secret, DcId, mtp_secure),
Cli1 = ping(Cli),
%% Receive the reporter notification to learn which Ranch/middle-server pid
%% this client's DC connection tunnels through.
ServerPid = receive {rpc_from, Pid, _, _} -> Pid end,
ok = mtp_test_middle_server:close_connection(ServerPid),
%% Wait until handler has successfully migrated to the surviving downstream.
ok = mtp_test_metric:wait_for_value(
count, [?APP, downstream_migration, total],
[?FUNCTION_NAME, ok], 1, 5000),
%% Client must still work after migration.
Cli2 = ping(Cli1),
%% Pool tracking must be clean: exactly 1 upstream registered.
?assertMatch(#{n_upstreams := 1}, mtp_dc_pool:status(Pool)),
ok = mtp_test_client:close(Cli2)
after
unregister(mtp_test_rpc_sink)
end.
%% @doc All clients survive when one of two DC connections is rotated.
downstream_migration_multi_case({pre, Cfg}) ->
setup_single(?FUNCTION_NAME, 10000 + ?LINE,
#{init_dc_connections => 2, rpc_handler => mtp_test_reporter_rpc}, Cfg);
downstream_migration_multi_case({post, Cfg}) ->
stop_single(Cfg);
downstream_migration_multi_case(Cfg) when is_list(Cfg) ->
DcId = ?config(dc_id, Cfg),
Host = ?config(mtp_host, Cfg),
Port = ?config(mtp_port, Cfg),
Secret = ?config(mtp_secret, Cfg),
Pool = mtp_dc_pool:dc_to_pool_name(DcId),
N = 3,
register(mtp_test_rpc_sink, self()),
try
Clients = [mtp_test_client:connect(Host, Port, Secret, DcId, mtp_secure)
|| _ <- lists:seq(1, N)],
Clients1 = [ping(C) || C <- Clients],
%% Drain all N {rpc_from,...} messages, grouping by ServerPid to find
%% which DC connection each client landed on.
Groups = lists:foldl(fun(_, Acc) ->
receive {rpc_from, Pid, _, _} ->
maps:update_with(Pid, fun(C) -> C + 1 end, 1, Acc)
end
end, #{}, lists:seq(1, N)),
%% Close the DC connection carrying the most clients.
{ServerPid, NOnServer} = hd(lists:reverse(lists:keysort(2, maps:to_list(Groups)))),
ok = mtp_test_middle_server:close_connection(ServerPid),
%% Wait until exactly NOnServer clients have successfully migrated.
ok = mtp_test_metric:wait_for_value(
count, [?APP, downstream_migration, total],
[?FUNCTION_NAME, ok], NOnServer, 5000),
Clients2 = [ping(C) || C <- Clients1],
?assertMatch(#{n_upstreams := N}, mtp_dc_pool:status(Pool)),
[ok = mtp_test_client:close(C) || C <- Clients2]
after
unregister(mtp_test_rpc_sink)
end,
ok.
%% @doc When pool is empty after DC rotation, client closes gracefully.
downstream_migration_empty_pool_case({pre, Cfg}) ->
setup_single(?FUNCTION_NAME, 10000 + ?LINE,
#{init_dc_connections => 0}, Cfg);
downstream_migration_empty_pool_case({post, Cfg}) ->
stop_single(Cfg);
downstream_migration_empty_pool_case(Cfg) when is_list(Cfg) ->
DcId = ?config(dc_id, Cfg),
Host = ?config(mtp_host, Cfg),
Port = ?config(mtp_port, Cfg),
Secret = ?config(mtp_secret, Cfg),
DcCfg = ?config(dc_conf, Cfg),
Pool = mtp_dc_pool:dc_to_pool_name(DcId),
%% Manually add one connection and wait for it to be ready
ok = mtp_dc_pool:add_connection(Pool),
ok = wait_for_pool_status(Pool, fun(S) -> maps:get(n_downstreams, S) >= 1 end, 5000),
Cli = mtp_test_client:connect(Host, Port, Secret, DcId, mtp_secure),
Cli1 = ping(Cli),
[Conn] = mtp_test_datacenter:middle_connections(DcCfg),
ok = mtp_test_middle_server:close_connection(Conn),
%% Pool stays empty (init_dc_connections=0 so no replacement spawned)
%% Client must close gracefully
?assertEqual({error, closed}, mtp_test_client:recv_packet(Cli1, 2000)).
setup_single(Name, MtpPort, DcCfg0, Cfg) ->
setup_single(Name, "127.0.0.1", MtpPort, DcCfg0, Cfg).
@@ -796,7 +896,11 @@ setup_single(Name, MtpIpStr, MtpPort, DcCfg0, Cfg) ->
secret => Secret,
tag => <<"dcbe8f1493fa4cd9ab300891c0b5b326">>}],
application:load(mtproto_proxy),
Cfg1 = set_env([{ports, Listeners}], Cfg),
AppEnv = case maps:find(init_dc_connections, DcCfg0) of
{ok, N} -> [{init_dc_connections, N}];
error -> []
end,
Cfg1 = set_env([{ports, Listeners}] ++ AppEnv, Cfg),
{ok, DcCfg} = mtp_test_datacenter:start_dc(PubKey, DcConf, DcCfg0),
{ok, _} = application:ensure_all_started(mtproto_proxy),
{ok, MtpIp} = inet:parse_address(MtpIpStr),
@@ -851,3 +955,23 @@ ping(Cli0) ->
{ok, Packet, Cli2} = mtp_test_client:recv_packet(Cli1, 1000),
?assertEqual(Data, Packet),
Cli2.
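%% Poll mtp_dc_pool:status/1 until Pred(Status) returns true; give up after Timeout ms.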
wait_for_pool_status(Pool, Pred, Timeout) ->
Deadline = erlang:monotonic_time(millisecond) + Timeout,
wait_for_pool_status_loop(Pool, Pred, Deadline).
wait_for_pool_status_loop(Pool, Pred, Deadline) ->
Status = mtp_dc_pool:status(Pool),
case Pred(Status) of
true ->
ok;
false ->
Remaining = Deadline - erlang:monotonic_time(millisecond),
case Remaining > 0 of
true ->
timer:sleep(50),
wait_for_pool_status_loop(Pool, Pred, Deadline);
false ->
{error, {timeout, Status}}
end
end.