%%
%% %CopyrightBegin%
%%
%% Copyright Ericsson AB 2007-2014. All Rights Reserved.
%%
%% The contents of this file are subject to the Erlang Public License,
%% Version 1.1, (the "License"); you may not use this file except in
%% compliance with the License. You should have received a copy of the
%% Erlang Public License along with this software. If not, it can be
%% retrieved online at http://www.erlang.org/.
%%
%% Software distributed under the License is distributed on an "AS IS"
%% basis, WITHOUT WARRANTY OF ANY KIND, either express or implied. See
%% the License for the specific language governing rights and limitations
%% under the License.
%%
%% %CopyrightEnd%
%%

%%

-module(ssl_ECC_SUITE).

%% Note: This directive should only be used in test suites.
-compile(export_all).

-include_lib("common_test/include/ct.hrl").
-include_lib("public_key/include/public_key.hrl").

%%--------------------------------------------------------------------
%% Common Test interface functions -----------------------------------
%%--------------------------------------------------------------------

%% Suite-wide Common Test configuration: install the ts_install_cth hook.
suite() ->
    Hooks = [ts_install_cth],
    [{ct_hooks, Hooks}].

%% Top-level run list: one group per TLS version, in this order.
all() ->
    [{group, Version} || Version <- ['tlsv1.2', 'tlsv1.1', 'tlsv1']].

%% Group definitions. Each TLS version group nests the role groups, and
%% each role group runs every key/certificate combination test case.
groups() ->
    VersionGroups = [{Version, [], all_versions_groups()}
                     || Version <- ['tlsv1.2', 'tlsv1.1', 'tlsv1']],
    RoleGroups = [{Role, [], key_cert_combinations()}
                  || Role <- [erlang_server, erlang_client, erlang]],
    VersionGroups ++ RoleGroups.

%% Role groups run under every TLS version group. init_per_group/2 decides
%% which side of the connection is Erlang and which is OpenSSL for each.
all_versions_groups() ->
    [{group, Role} || Role <- [erlang_server, erlang_client, erlang]].

%% Test cases run in every role group: the ECDH/RSA pairings first,
%% followed by the ECDSA pairings.
key_cert_combinations() ->
    EcdhCases = [client_ecdh_server_ecdh,
                 client_rsa_server_ecdh,
                 client_ecdh_server_rsa,
                 client_rsa_server_rsa],
    EcdsaCases = [client_ecdsa_server_ecdsa,
                  client_ecdsa_server_rsa,
                  client_rsa_server_ecdsa],
    EcdhCases ++ EcdsaCases.

%%--------------------------------------------------------------------
%% Start crypto and generate the certificate material shared by all test
%% cases. Returns the Config extended by ssl_test_lib:cert_options/1, or
%% {skip, "Crypto did not start"} when crypto cannot be started.
init_per_suite(Config0) ->
    %% end_per_suite/1 stops ssl and crypto, so crypto:start/0 starts fresh.
    end_per_suite(Config0),
    try crypto:start() of
        ok ->
            %% Make RSA certs using OpenSSL. Best effort: the outcome is only
            %% logged, and the suite continues either way.
            Result =
                (catch make_certs:all(?config(data_dir, Config0),
                                      ?config(priv_dir, Config0))),
            ct:log("Make certs  ~p~n", [Result]),
            Config1 = ssl_test_lib:make_ecdsa_cert(Config0),
            Config2 = ssl_test_lib:make_ecdh_rsa_cert(Config1),
            ssl_test_lib:cert_options(Config2);
        {error, Reason} ->
            %% crypto:start/0 reports failure by return value as well as by
            %% raising; without this clause the suite crashes with
            %% try_clause instead of being skipped.
            ct:log("crypto:start() failed: ~p~n", [Reason]),
            {skip, "Crypto did not start"}
    catch _:_ ->
            {skip, "Crypto did not start"}
    end.

%% Stop ssl, then crypto. Returns the result of stopping crypto; the result
%% of stopping ssl is discarded (it may not be running).
end_per_suite(_Config) ->
    lists:last([application:stop(App) || App <- [ssl, crypto]]).

%%--------------------------------------------------------------------
%% Role groups record which implementation runs the server and which runs
%% the client. They are skipped when the required ECC support check
%% fails. Any other group (the TLS version groups) is passed straight to
%% common_init_per_group/2.
init_per_group(erlang_client = Group, Config) ->
    init_role_group(Group, openssl, erlang,
                    ssl_test_lib:is_sane_ecc(openssl),
                    "Known ECC bug in openssl", Config);
init_per_group(erlang_server = Group, Config) ->
    init_role_group(Group, erlang, openssl,
                    ssl_test_lib:is_sane_ecc(openssl),
                    "Known ECC bug in openssl", Config);
init_per_group(erlang = Group, Config) ->
    init_role_group(Group, erlang, erlang,
                    ssl_test_lib:sufficient_crypto_support(Group),
                    "Crypto does not support ECC", Config);
init_per_group(openssl = Group, Config) ->
    init_role_group(Group, openssl, openssl,
                    ssl_test_lib:sufficient_crypto_support(Group),
                    "Crypto does not support ECC", Config);
init_per_group(Group, Config) ->
    common_init_per_group(Group, Config).

%% Shared body of the role-group clauses. Supported is the result of the
%% precondition check; SkipReason is used when it is false.
init_role_group(Group, ServerType, ClientType, Supported, SkipReason, Config) ->
    case Supported of
        true ->
            Tagged = [{server_type, ServerType}, {client_type, ClientType} | Config],
            common_init_per_group(Group, Tagged);
        false ->
            {skip, SkipReason}
    end.

%% A TLS version group calls ssl_test_lib:init_tls_version/1 and records the
%% version in Config as tls_version. Every other group goes through
%% openssl_check/2.
common_init_per_group(GroupName, Config) ->
    case ssl_test_lib:is_tls_version(GroupName) =:= true of
        false ->
            openssl_check(GroupName, Config);
        true ->
            ssl_test_lib:init_tls_version(GroupName),
            [{tls_version, GroupName} | Config]
    end.

%% No per-group cleanup is needed; Config is passed through unchanged.
end_per_group(_, Config) -> Config.

%%--------------------------------------------------------------------

%% Log the supported protocol versions and cipher suites, then restart the
%% ssl application: stop it via end_per_testcase/2 and start it again.
init_per_testcase(TestCase, Config) ->
    Versions = tls_record:supported_protocol_versions(),
    ct:log("TLS/SSL version ~p~n ", [Versions]),
    Ciphers = ssl:cipher_suites(),
    ct:log("Ciphers: ~p~n ", [Ciphers]),
    end_per_testcase(TestCase, Config),
    ssl:start(),
    Config.

%% Stop ssl after each test case and pass Config through unchanged.
end_per_testcase(_TestCase, Config) ->
    %% ssl may already be stopped; that result is deliberately ignored.
    _ = application:stop(ssl),
    Config.

%%--------------------------------------------------------------------
%% Test Cases --------------------------------------------------------
%%--------------------------------------------------------------------

%% Both sides use the ECDH certificate options
%% (client_ecdh_rsa_opts / server_ecdh_rsa_verify_opts).
client_ecdh_server_ecdh(Config) when is_list(Config) ->
    basic_test(?config(client_ecdh_rsa_opts, Config),
               ?config(server_ecdh_rsa_verify_opts, Config),
               Config).
    
%% Client uses the ECDH certificate options (client_ecdh_rsa_opts); server
%% uses the RSA options that client_rsa_server_rsa also uses
%% (server_verification_opts). The server must not reuse the ECDH server
%% options, or this case becomes identical to client_ecdh_server_ecdh.
client_ecdh_server_rsa(Config)  when is_list(Config) ->
    COpts = ?config(client_ecdh_rsa_opts, Config),
    SOpts = ?config(server_verification_opts, Config),
    basic_test(COpts, SOpts, Config).
  
%% Client uses the RSA options that client_rsa_server_rsa also uses
%% (client_verification_opts); server uses the ECDH certificate options
%% (server_ecdh_rsa_verify_opts). The client must not reuse the ECDH client
%% options, or this case becomes identical to client_ecdh_server_ecdh.
client_rsa_server_ecdh(Config)  when is_list(Config) ->
    COpts = ?config(client_verification_opts, Config),
    SOpts = ?config(server_ecdh_rsa_verify_opts, Config),
    basic_test(COpts, SOpts, Config).
   
%% Both sides use the RSA certificate options
%% (client_verification_opts / server_verification_opts).
client_rsa_server_rsa(Config)  when is_list(Config) ->
    basic_test(?config(client_verification_opts, Config),
               ?config(server_verification_opts, Config),
               Config).
   
%% Both sides use the ECDSA certificate options
%% (client_ecdsa_opts / server_ecdsa_verify_opts).
client_ecdsa_server_ecdsa(Config)  when is_list(Config) ->
    basic_test(?config(client_ecdsa_opts, Config),
               ?config(server_ecdsa_verify_opts, Config),
               Config).

%% Client uses the ECDSA certificate options (client_ecdsa_opts); server
%% uses the RSA options that client_rsa_server_rsa also uses
%% (server_verification_opts). The server must not reuse the ECDSA server
%% options, or this case becomes identical to client_ecdsa_server_ecdsa.
client_ecdsa_server_rsa(Config)  when is_list(Config) ->
    COpts = ?config(client_ecdsa_opts, Config),
    SOpts = ?config(server_verification_opts, Config),
    basic_test(COpts, SOpts, Config).

%% Client uses the RSA options that client_rsa_server_rsa also uses
%% (client_verification_opts); server uses the ECDSA certificate options
%% (server_ecdsa_verify_opts). The client must not reuse the ECDSA client
%% options, or this case becomes identical to client_ecdsa_server_ecdsa.
client_rsa_server_ecdsa(Config)  when is_list(Config) ->
    COpts = ?config(client_verification_opts, Config),
    SOpts = ?config(server_ecdsa_verify_opts, Config),
    basic_test(COpts, SOpts, Config).

%%--------------------------------------------------------------------
%% Internal functions ------------------------------------------------
%%--------------------------------------------------------------------
%% Extract certfile, keyfile and cacertfile from the client option list
%% (COpts) and the server option list (SOpts), then run basic_test/7.
basic_test(COpts, SOpts, Config) ->
    FileKeys = [certfile, keyfile, cacertfile],
    [ClientCert, ClientKey, ClientCA] =
        [proplists:get_value(FileKey, COpts) || FileKey <- FileKeys],
    [ServerCert, ServerKey, ServerCA] =
        [proplists:get_value(FileKey, SOpts) || FileKey <- FileKeys],
    basic_test(ClientCert, ClientKey, ClientCA,
               ServerCert, ServerKey, ServerCA, Config).
    
%% Start a server and then a client, using the implementations named by
%% server_type and client_type in Config. Then check the result of each
%% Erlang endpoint and close both endpoints.
%% An Erlang server trusts ClientCA and an Erlang client trusts ServerCA;
%% an OpenSSL endpoint trusts a bundle of both (see new_ca/3).
basic_test(ClientCert, ClientKey, ClientCA, ServerCert, ServerKey, ServerCA, Config) ->
    ServerType = ?config(server_type, Config),
    ClientType = ?config(client_type, Config),
    {Server, ListenPort} =
        start_server(ServerType, ClientCA, ServerCA,
                     ServerCert, ServerKey, Config),
    Client =
        start_client(ClientType, ListenPort, ServerCA, ClientCA,
                     ClientCert, ClientKey, Config),
    check_result(Server, ServerType, Client, ClientType),
    close(Server, Client).

%% Start the client endpoint of a connection to Port on localhost.
%%
%% openssl: writes a CA bundle (CA plus OwnCa) to priv_dir/new_ca.pem,
%%   spawns `openssl s_client` trusting that bundle, writes "Hello world"
%%   to its stdin, and returns the port.
%% erlang: starts an ssl_test_lib client on the client node that verifies
%%   the peer against CA; OwnCa is unused. Returns the result of
%%   ssl_test_lib:start_client/1.
start_client(openssl, Port, CA, OwnCa, Cert, Key, Config) ->
    CaBundle = new_ca(filename:join(?config(priv_dir, Config), "new_ca.pem"),
                      CA, OwnCa),
    Version = tls_record:protocol_version(
                tls_record:highest_protocol_version([])),
    Cmd = lists:append(["openssl s_client -verify 2 -port ",
                        integer_to_list(Port),
                        ssl_test_lib:version_flag(Version),
                        " -cert ", Cert,
                        " -CAfile ", CaBundle,
                        " -key ", Key,
                        " -host localhost -msg -debug"]),
    OpenSslPort = open_port({spawn, Cmd}, [stderr_to_stdout]),
    true = port_command(OpenSslPort, "Hello world"),
    OpenSslPort;
start_client(erlang, Port, CA, _OwnCa, Cert, Key, Config) ->
    {ClientNode, _ServerNode, Hostname} = ssl_test_lib:run_where(Config),
    SslOpts = [{verify, verify_peer},
               {cacertfile, CA},
               {certfile, Cert},
               {keyfile, Key}],
    ssl_test_lib:start_client([{node, ClientNode},
                               {port, Port},
                               {host, Hostname},
                               {from, self()},
                               {mfa, {ssl_test_lib, send_recv_result_active, []}},
                               {options, SslOpts}]).

%% Start the server endpoint and return {ServerHandle, ListenPort}.
%%
%% openssl: writes a CA bundle (CA plus OwnCa) to priv_dir/new_ca.pem,
%%   picks a port with ssl_test_lib:inet_port(node()), spawns
%%   `openssl s_server` on it trusting that bundle, waits via
%%   ssl_test_lib:wait_for_openssl_server/0, then writes "Hello world" to
%%   its stdin. The handle is the Erlang port.
%% erlang: starts an ssl_test_lib server on the server node on port 0 that
%%   verifies the peer against CA; OwnCa is unused. The listen port is read
%%   back from the started server.
start_server(openssl, CA, OwnCa, Cert, Key, Config) ->
    CaBundle = new_ca(filename:join(?config(priv_dir, Config), "new_ca.pem"),
                      CA, OwnCa),
    ListenPort = ssl_test_lib:inet_port(node()),
    Version = tls_record:protocol_version(
                tls_record:highest_protocol_version([])),
    Cmd = lists:append(["openssl s_server -accept ",
                        integer_to_list(ListenPort),
                        ssl_test_lib:version_flag(Version),
                        " -verify 2 -cert ", Cert,
                        " -CAfile ", CaBundle,
                        " -key ", Key,
                        " -msg -debug"]),
    OpenSslPort = open_port({spawn, Cmd}, [stderr_to_stdout]),
    ssl_test_lib:wait_for_openssl_server(),
    true = port_command(OpenSslPort, "Hello world"),
    {OpenSslPort, ListenPort};
start_server(erlang, CA, _OwnCa, Cert, Key, Config) ->
    {_ClientNode, ServerNode, _Hostname} = ssl_test_lib:run_where(Config),
    SslOpts = [{verify, verify_peer},
               {cacertfile, CA},
               {certfile, Cert},
               {keyfile, Key}],
    Server = ssl_test_lib:start_server([{node, ServerNode},
                                        {port, 0},
                                        {from, self()},
                                        {mfa, {ssl_test_lib,
                                               send_recv_result_active,
                                               []}},
                                        {options, SslOpts}]),
    {Server, ssl_test_lib:inet_port(Server)}.

%% Wait for the `ok` result from each endpoint that is an Erlang process.
%% OpenSSL endpoints are not checked, so an openssl/openssl pair returns
%% ok immediately.
check_result(Server, erlang, Client, ClientType) ->
    case ClientType of
        erlang -> ssl_test_lib:check_result(Server, ok, Client, ok);
        _ -> ssl_test_lib:check_result(Server, ok)
    end;
check_result(_Server, _ServerType, Client, erlang) ->
    ssl_test_lib:check_result(Client, ok);
check_result(_Server, openssl, _Client, openssl) ->
    ok.

%% The pure-Erlang group needs no OpenSSL check. Any other group is skipped
%% unless ssl_test_lib:check_sane_openssl_version/1 accepts the tls_version
%% recorded in Config.
openssl_check(GroupName, Config) when GroupName =:= erlang ->
    Config;
openssl_check(_GroupName, Config) ->
    case ssl_test_lib:check_sane_openssl_version(?config(tls_version, Config)) of
        false -> {skip, "TLS version not supported by openssl"};
        true -> Config
    end.

%% Close both endpoints. Each one is either an OpenSSL port, closed with
%% ssl_test_lib:close_port/1, or an Erlang process, closed with
%% ssl_test_lib:close/1. When exactly one is a port, the port is closed
%% first. When both are processes, the second argument is closed first.
close(First, Second) ->
    case {is_port(First), is_port(Second)} of
        {true, true} ->
            ssl_test_lib:close_port(First),
            ssl_test_lib:close_port(Second);
        {true, false} ->
            ssl_test_lib:close_port(First),
            ssl_test_lib:close(Second);
        {false, true} ->
            ssl_test_lib:close_port(Second),
            ssl_test_lib:close(First);
        {false, false} ->
            ssl_test_lib:close(Second),
            ssl_test_lib:close(First)
    end.

%% Work around OpenSSL bug, apparently the same bug as we had fixed in
%% 11629690ba61f8e0c93ef9b2b6102fd279825977
%%
%% Write FileName as a PEM bundle containing the entries of OwnCa followed
%% by those of CA, and return FileName. The caller passes it to OpenSSL as
%% -CAfile. Crashes with badmatch if either input cannot be read or the
%% bundle cannot be written, rather than letting OpenSSL pick up a missing
%% or stale file.
new_ca(FileName, CA, OwnCa) ->
    {ok, CaPem} = file:read_file(CA),
    {ok, OwnPem} = file:read_file(OwnCa),
    Entries = public_key:pem_decode(OwnPem) ++ public_key:pem_decode(CaPem),
    ok = file:write_file(FileName, public_key:pem_encode(Entries)),
    FileName.
