move upstream project (https://github.com/itsxaos/stub) into subdir
This commit is contained in:
parent
e36782c04b
commit
99a5e07224
11 changed files with 0 additions and 0 deletions
260
caddyapp/server.go
Normal file
260
caddyapp/server.go
Normal file
|
|
@ -0,0 +1,260 @@
|
|||
package stub
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"strings"
|
||||
|
||||
"github.com/caddyserver/caddy/v2"
|
||||
"github.com/miekg/dns"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// A DNS Query coming in from outside
|
||||
type query struct {
|
||||
w dns.ResponseWriter
|
||||
r *dns.Msg
|
||||
}
|
||||
|
||||
type key struct {
|
||||
Type dns.Type
|
||||
Name string
|
||||
}
|
||||
|
||||
type Server struct {
|
||||
// the address & port on which to serve DNS for the challenge
|
||||
Address caddy.NetworkAddress `json:"address,omitempty"`
|
||||
|
||||
// Statically configured records to serve
|
||||
Records map[key][]dns.RR `json:"records,omitempty"`
|
||||
|
||||
logger *zap.Logger // set by App.start()
|
||||
ctx *caddy.Context // set by App.start()
|
||||
shutdown chan struct{} // set by App.start()
|
||||
requests chan request // set by App.start()
|
||||
|
||||
dns_server *dns.Server // set by start_stop_server()
|
||||
queries chan query // set by start_stop_server()
|
||||
|
||||
}
|
||||
|
||||
func rr_key(record dns.RR) key {
|
||||
return key{
|
||||
Type: dns.Type(record.Header().Rrtype),
|
||||
Name: strings.ToLower(record.Header().Name),
|
||||
}
|
||||
}
|
||||
|
||||
func (srv *Server) insert_record(record dns.RR) {
|
||||
key := rr_key(record)
|
||||
current, exists := srv.Records[key]
|
||||
if exists {
|
||||
// TODO: de-duplicate?
|
||||
srv.Records[key] = append(current, record)
|
||||
} else {
|
||||
srv.Records[key] = []dns.RR{record}
|
||||
}
|
||||
}
|
||||
|
||||
func (srv *Server) delete_record(record dns.RR) bool {
|
||||
key := rr_key(record)
|
||||
current, exists := srv.Records[key]
|
||||
if exists {
|
||||
filtered := []dns.RR{}
|
||||
as_string := record.String()
|
||||
for _, rec := range current {
|
||||
// TODO: see if there is a more efficient way to compare these
|
||||
// just rec != record does not seem to work, might be doing ptr eq
|
||||
if rec.String() != as_string {
|
||||
filtered = append(filtered, rec)
|
||||
}
|
||||
}
|
||||
if len(filtered) == 0 {
|
||||
delete(srv.Records, key)
|
||||
} else {
|
||||
srv.Records[key] = filtered
|
||||
}
|
||||
return len(filtered) < len(current)
|
||||
} else {
|
||||
// doesn't exist, nothing to delete
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// This is the "main loop" of the DNS server
|
||||
// To avoid having to synchronize access to the records map, it is owned
|
||||
// exclusively by this loop, and the methods it calls.
|
||||
// All DNS queries coming from outside, as well as all requests to create
|
||||
// or delete DNS records coming from within the process are serialized by
|
||||
// the select statement.
|
||||
func (srv *Server) main() {
|
||||
srv.logger.Debug(
|
||||
"main loop running",
|
||||
zap.Int("record_count", len(srv.Records)),
|
||||
)
|
||||
for {
|
||||
select {
|
||||
case r := <-srv.requests:
|
||||
srv.handle_request(r)
|
||||
case q := <-srv.queries:
|
||||
srv.handle_query(q)
|
||||
case <-srv.shutdown:
|
||||
srv.logger.Debug("stopping main loop")
|
||||
if srv.dns_server != nil {
|
||||
srv.dns_server.Shutdown()
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (srv *Server) handle_request(r request) {
|
||||
var count_field zap.Field
|
||||
if r.append {
|
||||
for _, record := range r.records {
|
||||
srv.insert_record(record)
|
||||
}
|
||||
count_field = zap.Int("appended_records", len(r.records))
|
||||
} else {
|
||||
count := 0
|
||||
for _, record := range r.records {
|
||||
if srv.delete_record(record) {
|
||||
count += 1
|
||||
}
|
||||
}
|
||||
count_field = zap.Int("deleted_records", count)
|
||||
}
|
||||
|
||||
srv.logger.Debug("handled", zap.Object("request", r), count_field)
|
||||
|
||||
r.responder <- srv.start_stop_server()
|
||||
}
|
||||
|
||||
func (srv *Server) start_stop_server() error {
|
||||
if srv.queries == nil {
|
||||
srv.queries = make(chan query)
|
||||
}
|
||||
if len(srv.Records) == 0 {
|
||||
if srv.dns_server != nil {
|
||||
srv.logger.Debug("no more records to serve, shutting down server")
|
||||
err := srv.dns_server.Shutdown()
|
||||
srv.dns_server = nil
|
||||
return err
|
||||
}
|
||||
srv.logger.Debug("no records to serve")
|
||||
return nil
|
||||
} else {
|
||||
if srv.dns_server == nil {
|
||||
conn, err := srv.bind()
|
||||
if err != nil {
|
||||
srv.logger.Error(
|
||||
"failed to bind",
|
||||
zap.Stringer("address", srv.Address),
|
||||
zap.Error(err),
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
// spawn the server
|
||||
handler := make_proxy(srv.queries)
|
||||
server := &dns.Server{
|
||||
PacketConn: conn,
|
||||
Net: "udp",
|
||||
Handler: handler,
|
||||
TsigSecret: nil,
|
||||
}
|
||||
srv.logger.Debug(
|
||||
"starting server",
|
||||
zap.Int("record_count", len(srv.Records)),
|
||||
)
|
||||
go srv.serve(server)
|
||||
|
||||
// store the server for shutdown later
|
||||
srv.dns_server = server
|
||||
return nil
|
||||
}
|
||||
srv.logger.Debug(
|
||||
"server already running",
|
||||
zap.Int("record_count", len(srv.Records)),
|
||||
)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (srv *Server) bind() (net.PacketConn, error) {
|
||||
conn, err := srv.Address.Listen(srv.ctx, 0, net.ListenConfig{})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
pkt_conn := conn.(net.PacketConn)
|
||||
if pkt_conn == nil {
|
||||
return nil, errors.New("invalid address")
|
||||
}
|
||||
srv.logger.Debug("bound to socket", zap.Stringer("address", srv.Address))
|
||||
return pkt_conn, nil
|
||||
}
|
||||
|
||||
func (srv *Server) handle_query(q query) {
|
||||
// dns.DefaultMsgAcceptFunc already checks that the query is fairly
|
||||
// reasonable.
|
||||
|
||||
m := new(dns.Msg)
|
||||
m.SetReply(q.r)
|
||||
|
||||
reject_and_log := func(code int, reason string) {
|
||||
m.Rcode = code
|
||||
m.Answer = []dns.RR{}
|
||||
srv.logger.Debug(
|
||||
"rejecting query",
|
||||
zap.Stringer("address", q.w.RemoteAddr()),
|
||||
zap.String("reason", reason),
|
||||
zap.Object("response", LoggableDNSMsg{m}),
|
||||
)
|
||||
q.w.WriteMsg(m)
|
||||
}
|
||||
|
||||
qstn := q.r.Question[0]
|
||||
if !(qstn.Qclass == dns.ClassINET || qstn.Qclass == dns.ClassANY) {
|
||||
// TODO: consider just not worrying about this
|
||||
reject_and_log(dns.RcodeNotImplemented, "invalid class")
|
||||
return
|
||||
}
|
||||
// queries may be wAcKY casE
|
||||
// https://datatracker.ietf.org/doc/html/draft-vixie-dnsext-dns0x20-00
|
||||
key := key{
|
||||
Type: dns.Type(qstn.Qtype),
|
||||
Name: strings.ToLower(qstn.Name),
|
||||
}
|
||||
records, exists := srv.Records[key]
|
||||
if !exists {
|
||||
reject_and_log(dns.RcodeNameError, "no such record")
|
||||
return
|
||||
}
|
||||
|
||||
m.Authoritative = true
|
||||
m.Answer = records
|
||||
|
||||
srv.logger.Debug(
|
||||
"answering query",
|
||||
zap.Stringer("address", q.w.RemoteAddr()),
|
||||
zap.Object("response", LoggableDNSMsg{m}),
|
||||
)
|
||||
q.w.WriteMsg(m)
|
||||
}
|
||||
|
||||
func (srv *Server) serve(server *dns.Server) {
|
||||
err := server.ActivateAndServe()
|
||||
if err != nil {
|
||||
srv.logger.Error("dns.ActivateAndServe failed", zap.Error(err))
|
||||
} else {
|
||||
srv.logger.Debug("server terminated successfully")
|
||||
}
|
||||
}
|
||||
|
||||
// dns.HandlerFunc that forwards every query into a channel
|
||||
func make_proxy(sink chan query) dns.HandlerFunc {
|
||||
return func(w dns.ResponseWriter, r *dns.Msg) {
|
||||
q := query{w, r}
|
||||
sink <- q
|
||||
}
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue