local nmap = require "nmap"
local shortport = require "shortport"
local sslcert = require "sslcert"
local stdnse = require "stdnse"
local string = require "string"
local table = require "table"
local tableaux = require "tableaux"
local tls = require "tls"
local listop = require "listop"

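description = [[
    Checks whether the server honors the TLS_FALLBACK_SCSV signaling cipher
    suite value: a ClientHello that includes it is sent, and the server is
    expected to answer with an "inappropriate_fallback" alert.
]]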

-- @output
--PORT    STATE SERVICE
--443/tcp open  https
--| altx-soft_TLS_FALLBACK_SCSV: 
--|_  TLS_FALLBACK_SCSV: SCSV signature encryption is not supported

author = "Alex Petrov"

	
-- Script argument "stdnse.use_portrule": only the literal string "true"
-- switches on the SSL-aware port selection; anything else matches every port.
local use_portrule = stdnse.get_script_args("stdnse.use_portrule") or false

portrule = function(host, port)
	if use_portrule ~= "true" then
		return true
	end
	return shortport.ssl(host, port) or sslcert.getPrepareTLSWithoutReconnect(port)
end

-- Number of cipher names offered per ClientHello; find_ciphers passes this to
-- in_chunks when splitting the cipher list into groups.
local CHUNK_SIZE = 64

--- Emit a debug message prefixed with the protocol it concerns.
-- @param level    debug level handed to stdnse.print_debug
-- @param protocol protocol name, rendered as "(protocol) " before the message
-- @param fmt      format string for the rest of the message
-- @param ...      values substituted into fmt
-- @return whatever stdnse.print_debug returns
local function ctx_log(level, protocol, fmt, ...)
    local prefixed_fmt = "(%s) " .. fmt
    return stdnse.print_debug(level, prefixed_fmt, protocol, ...)
end

--- Collect the keys of a table into a new array and sort it.
-- @param t table whose keys are wanted
-- @return array of t's keys, ordered by table.sort's default comparison
local function sorted_keys(t)
    local keys = {}
    local count = 0
    for key in pairs(t) do
      count = count + 1
      keys[count] = key
    end
    table.sort(keys)
    return keys
end

--- Delete the first occurrence of a value from an array, in place.
-- @param t array to modify
-- @param e value to look for
-- @return index the value was removed from, or nil when it was not present
local function remove(t, e)
    local idx = 1
    local current = t[idx]
    -- Walk until the first nil, the same stopping point ipairs uses.
    while current ~= nil do
      if current == e then
        table.remove(t, idx)
        return idx
      end
      idx = idx + 1
      current = t[idx]
    end
    return nil
end

--- Send one ClientHello built from `t` and return the first meaningful record.
-- Opens a connection (through the STARTTLS-style helper from sslcert when one
-- exists for the port, otherwise a plain socket), sends tls.client_hello(t)
-- and reads TLS records until one arrives that is not a warning-level alert.
-- The socket is closed on every exit path.
-- @param host host table
-- @param port port table
-- @param t    parameter table for tls.client_hello; t.protocol is also used
--             as the log prefix
-- @return the parsed record, or nil on connect/send/read failure
local function try_params(host, port, t)
    local timeout = ((host.times and host.times.timeout) or 5) * 1000 + 5000

    -- Create socket.
    local status, sock, err
    local specialized = sslcert.getPrepareTLSWithoutReconnect(port)
    if specialized then
      status, sock = specialized(host, port)
      if not status then
        -- On failure the helper's second return value is the error message.
        ctx_log(1, t.protocol, "Can't connect: %s", sock)
        return nil
      end
    else
      sock = nmap.new_socket()
      sock:set_timeout(timeout)
      status, err = sock:connect(host, port)
      if not status then
        ctx_log(1, t.protocol, "Can't connect: %s", err)
        sock:close()
        return nil
      end
    end

    -- Applies to both paths: the specialized helper returns its own socket.
    sock:set_timeout(timeout)

    -- Send request.
    local req = tls.client_hello(t)
    status, err = sock:send(req)
    if not status then
      ctx_log(1, t.protocol, "Can't send: %s", err)
      sock:close()
      return nil
    end

    -- Read response.
    local buffer = ""
    local i = 1
    while true do
      status, buffer, err = tls.record_buffer(sock, buffer, i)
      if not status then
        ctx_log(1, t.protocol, "Couldn't read a TLS record: %s", err)
        -- Close here as well; this path used to leak the socket.
        sock:close()
        return nil
      end
      -- Parse response.
      local record
      i, record = tls.record_read(buffer, i)
      if record and record.type == "alert" and record.body[1].level == "warning" then
        ctx_log(1, t.protocol, "Ignoring warning: %s", record.body[1].description)
        -- Try again.
      elseif record then
        sock:close()
        return record
      end
    end
end

--- Split an array into consecutive sub-arrays of at most `size` elements.
-- @param t    array to split
-- @param size maximum number of elements per chunk
-- @return array of chunks; the final chunk may be shorter than `size`
local function in_chunks(t, size)
    local chunks = {}
    for first = 1, #t, size do
      local piece = {}
      local last = first + size - 1
      local pos = first
      while pos <= last do
        -- Past the end of t this assigns nil, which appends nothing.
        piece[#piece + 1] = t[pos]
        pos = pos + 1
      end
      chunks[#chunks + 1] = piece
    end
    return chunks
end

--- Build the ClientHello extensions used by every probe in this script.
-- @param host host table, passed to tls.servername
-- @return table with an "elliptic_curves" entry and, when tls.servername
--         yields a name, a "server_name" entry
local function base_extensions(host)
    local sni_name = tls.servername(host)
    local helpers = tls.EXTENSION_HELPERS
    local extensions = {}
    -- Claim to support common elliptic curves
    extensions["elliptic_curves"] = helpers["elliptic_curves"](tls.DEFAULT_ELLIPTIC_CURVES)
    -- Enable SNI if a server name is available
    extensions["server_name"] = sni_name and helpers["server_name"](sni_name)
    return extensions
end

--- Return the cipher names whose numeric value in tls.CIPHERS is at most 255.
-- The input array is left untouched.
-- @param t array of cipher names (each must be a key of tls.CIPHERS)
-- @return new array holding the retained names in their original order
local function remove_high_byte_ciphers(t)
    local kept = {}
    for _, cipher_name in ipairs(t) do
      local code = tls.CIPHERS[cipher_name]
      if code <= 255 then
        kept[#kept + 1] = cipher_name
      end
    end
    return kept
end

--- Probe one group of ciphers and report the first one the server picks.
-- Sends a ClientHello offering `group` under `protocol` and classifies the
-- reply. The loop ends as soon as one offered cipher is accepted, so
-- `results` holds at most one name.
-- @param host     host table
-- @param port     port table
-- @param protocol protocol name placed in the ClientHello
-- @param group    array of cipher names; modified in place when a cipher is
--                 accepted (the accepted name is removed from it)
-- @return results array with the accepted cipher name, or empty
-- @return protocol_worked true when the server answered under this protocol
--         (server_hello or a handshake_failure alert), nil when the reply
--         carried a different protocol, false otherwise
local function find_ciphers_group(host, port, protocol, group)
    local name, protocol_worked, record, results
    results = {}
    local t = {
      ["protocol"] = protocol,
      ["extensions"] = base_extensions(host),
    }

    protocol_worked = false
    while (next(group)) do
      t["ciphers"] = group
  
      record = try_params(host, port, t)
  
      if record == nil then
        -- No usable record came back (connect, send or read failed).
        if protocol_worked then
          ctx_log(2, protocol, "%d ciphers rejected. (No handshake)", #group)
        else
          ctx_log(1, protocol, "%d ciphers and/or protocol rejected. (No handshake)", #group)
        end
        break
      elseif record["protocol"] ~= protocol or record["body"][1]["protocol"] and record.body[1].protocol ~= protocol then
        -- The record, or its first body message, names another protocol.
        ctx_log(1, protocol, "Protocol rejected.")
        protocol_worked = nil
        break
      elseif record["type"] == "alert" and record["body"][1]["description"] == "handshake_failure" then
        -- Same protocol, but none of the offered ciphers was acceptable.
        protocol_worked = true
        ctx_log(2, protocol, "%d ciphers rejected.", #group)
        break
      elseif record["type"] ~= "handshake" or record["body"][1]["type"] ~= "server_hello" then
        ctx_log(2, protocol, "Unexpected record received.")
        break
      else
        -- server_hello: the server selected a cipher.
        protocol_worked = true
        name = record["body"][1]["cipher"]
        ctx_log(1, protocol, "Cipher %s chosen.", name)
        if not remove(group, name) then
          -- The chosen cipher was not in the offered group; retry with the
          -- high-byte ciphers filtered out, unless that changes nothing.
          ctx_log(1, protocol, "chose cipher %s that was not offered.", name)
          ctx_log(1, protocol, "removing high-byte ciphers and trying again.")
          local size_before = #group
          group = remove_high_byte_ciphers(group)
          ctx_log(1, protocol, "removed %d high-byte ciphers.", size_before - #group)
          if #group == size_before then
            break
          end
        else
          -- One accepted cipher is enough for the caller; stop here.
          table.insert(results, name)
          break
        end
      end
    end
    return results, protocol_worked
end

-- Candidate cipher names: the sorted keys of tls.CIPHERS that contain an
-- underscore (plain-text search, no pattern matching).
local _ciphers
do
  local function has_underscore(x)
    return string.find(x, "_", 1, true)
  end
  _ciphers = listop.filter(has_underscore, sorted_keys(tls.CIPHERS))
end

--- Look for a cipher the server accepts under the given protocol.
-- Offers the candidate list in groups of CHUNK_SIZE and stops at the first
-- group that produces an accepted cipher.
-- @param host     host table
-- @param port     port table
-- @param protocol protocol name to probe
-- @return nil when a group reports the protocol as rejected; otherwise an
--         array of accepted cipher names, which may be empty
local function find_ciphers(host, port, protocol)
    local accepted = {}

    for _, group in ipairs(in_chunks(_ciphers, CHUNK_SIZE)) do
      local found, protocol_worked = find_ciphers_group(host, port, protocol, group)
      if protocol_worked == nil then
        return nil
      end

      for k = 1, #found do
        accepted[#accepted + 1] = found[k]
      end

      if protocol_worked and next(accepted) then
        return accepted
      end
    end

    return accepted
end

--- Check whether the server rejects a ClientHello carrying TLS_FALLBACK_SCSV.
-- Sends the given ciphers plus "TLS_FALLBACK_SCSV" under `protocol` and looks
-- for an "inappropriate_fallback" alert in the reply.
-- @param host     host table
-- @param port     port table
-- @param protocol protocol name placed in the ClientHello
-- @param ciphers  array of cipher names to offer; it is copied, not modified
-- @return true when the server answered with an inappropriate_fallback
--         alert, false otherwise (including when no record was received)
local function check_fallback_scsv(host, port, protocol, ciphers)
    local t = {
      ["protocol"] = protocol,
      ["extensions"] = base_extensions(host),
    }

    t["ciphers"] = tableaux.tcopy(ciphers)
    t.ciphers[#t.ciphers+1] = "TLS_FALLBACK_SCSV"

    -- The two library tables are checked independently. Tying the cipher
    -- registration to the alert entry left later calls without the cipher,
    -- because the alert entry persists after the first call.
    if not tls.TLS_ALERT_REGISTRY["inappropriate_fallback"] then
      tls.TLS_ALERT_REGISTRY["inappropriate_fallback"] = 86
    end
    local added_cipher = false
    if tls.CIPHERS["TLS_FALLBACK_SCSV"] == nil then
      tls.CIPHERS["TLS_FALLBACK_SCSV"] = 0x5600
      added_cipher = true
    end

    local record = try_params(host, port, t)

    -- Undo only what this call added, so an entry the tls library already
    -- provides is never deleted from the shared table.
    if added_cipher then
      tls.CIPHERS["TLS_FALLBACK_SCSV"] = nil
    end

    if record and record["type"] == "alert" and record["body"][1]["description"] == "inappropriate_fallback" then
      ctx_log(2, protocol, "TLS_FALLBACK_SCSV rejected properly.")
      return true
    end
    return false
end

--- Script entry point.
-- Tries each protocol in order until one yields an accepted cipher, then
-- runs the TLS_FALLBACK_SCSV check with that protocol and cipher.
-- @param host host table
-- @param port port table
-- @return an output table with a "TLS_FALLBACK_SCSV" entry when the server
--         did not answer with inappropriate_fallback; nil when it did, or
--         when no protocol produced an accepted cipher
action = function(host, port)
    local name_ciphers_tlslib = {"SSLv3", "TLSv1.0", "TLSv1.1", "TLSv1.2"}

    for _, protocol in ipairs(name_ciphers_tlslib) do
        local ciphers = find_ciphers(host, port, protocol)

        -- find_ciphers returns an empty table, not nil, when nothing was
        -- negotiated. Such a protocol must be skipped, otherwise a port that
        -- never completed a handshake is reported as lacking SCSV support.
        if ciphers ~= nil and next(ciphers) ~= nil then
            if check_fallback_scsv(host, port, protocol, ciphers) then
                return nil
            end
            -- NOTE(review): if `protocol` is already the highest version the
            -- server supports, a missing alert may be legitimate - confirm.
            local out_table = stdnse.output_table()
            out_table["TLS_FALLBACK_SCSV"] = "SCSV signature encryption is not supported"
            return out_table
        end
    end
    return nil
end