1
package doh
2

3
import (
	"encoding/base64"
	"encoding/json"
	"io"
	"io/ioutil"
	"net"
	"net/http"
	"strings"

	"github.com/miekg/dns"
)
13

14
// HandleWireFormat handle wire format
15
func HandleWireFormat(handle func(*dns.Msg) *dns.Msg) func(http.ResponseWriter, *http.Request) {
16 1
	return func(w http.ResponseWriter, r *http.Request) {
17 1
		var (
18 1
			buf []byte
19 1
			err error
20 1
		)
21

22 1
		switch r.Method {
23 1
		case http.MethodGet:
24 1
			buf, err = base64.RawURLEncoding.DecodeString(r.URL.Query().Get("dns"))
25 1
			if len(buf) == 0 || err != nil {
26 0
				http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
27 0
				return
28
			}
29 1
		case http.MethodPost:
30 1
			if r.Header.Get("Content-Type") != "application/dns-message" {
31 1
				http.Error(w, http.StatusText(http.StatusUnsupportedMediaType), http.StatusUnsupportedMediaType)
32 1
				return
33
			}
34

35 1
			buf, err = ioutil.ReadAll(r.Body)
36 1
			if err != nil {
37 0
				http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
38 0
				return
39
			}
40 1
			defer r.Body.Close()
41 1
		default:
42 1
			http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed)
43 1
			return
44
		}
45

46 1
		req := new(dns.Msg)
47 1
		if err := req.Unpack(buf); err != nil {
48 1
			http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
49 1
			return
50
		}
51

52 1
		msg := handle(req)
53 1
		if msg == nil {
54 0
			http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
55 0
			return
56
		}
57

58 1
		packed, err := msg.Pack()
59 1
		if err != nil {
60 0
			http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
61 0
			return
62
		}
63

64 1
		w.Header().Set("Server", "SDNS")
65 1
		w.Header().Set("Content-Type", "application/dns-message")
66

67 1
		_, _ = w.Write(packed)
68
	}
69
}
70

71
// HandleJSON handle json format
72
func HandleJSON(handle func(*dns.Msg) *dns.Msg) func(http.ResponseWriter, *http.Request) {
73 1
	return func(w http.ResponseWriter, r *http.Request) {
74 1
		name := r.URL.Query().Get("name")
75 1
		if name == "" {
76 1
			http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
77 1
			return
78
		}
79 1
		name = dns.Fqdn(name)
80

81 1
		qtype := ParseQTYPE(r.URL.Query().Get("type"))
82 1
		if qtype == dns.TypeNone {
83 0
			http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
84 0
			return
85
		}
86

87 1
		req := new(dns.Msg)
88 1
		req.SetQuestion(name, qtype)
89 1
		req.AuthenticatedData = true
90

91 1
		if r.URL.Query().Get("cd") == "true" {
92 1
			req.CheckingDisabled = true
93
		}
94

95 1
		opt := &dns.OPT{
96 1
			Hdr: dns.RR_Header{
97 1
				Name:   ".",
98 1
				Class:  dns.DefaultMsgSize,
99 1
				Rrtype: dns.TypeOPT,
100 1
			},
101
		}
102

103 1
		if r.URL.Query().Get("do") == "true" {
104 1
			opt.SetDo()
105
		}
106

107 1
		if ecs := r.URL.Query().Get("edns_client_subnet"); ecs != "" {
108 1
			_, subnet, err := net.ParseCIDR(ecs)
109 1
			if err != nil {
110 0
				http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
111 0
				return
112
			}
113

114 1
			mask, bits := subnet.Mask.Size()
115 1
			var af uint16
116 1
			if bits == 32 {
117 1
				af = 1
118 1
			} else {
119 0
				af = 2
120
			}
121

122 1
			opt.Option = []dns.EDNS0{
123 1
				&dns.EDNS0_SUBNET{
124 1
					Code:          dns.EDNS0SUBNET,
125 1
					Family:        af,
126 1
					SourceNetmask: uint8(mask),
127 1
					SourceScope:   0,
128 1
					Address:       subnet.IP,
129 1
				},
130
			}
131
		}
132

133 1
		req.Extra = append(req.Extra, opt)
134

135 1
		msg := handle(req)
136 1
		if msg == nil {
137 0
			http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
138 0
			return
139
		}
140

141 1
		json, err := json.Marshal(NewMsg(msg))
142 1
		if err != nil {
143 0
			http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
144 0
			return
145
		}
146

147 1
		w.Header().Set("Server", "SDNS")
148

149 1
		if strings.Contains(r.Header.Get("Accept"), "text/html") {
150 1
			w.Header().Set("Content-Type", "application/x-javascript")
151 1
		} else {
152 1
			w.Header().Set("Content-Type", "application/dns-json")
153
		}
154

155 1
		_, _ = w.Write(json)
156
	}
157
}

Read our documentation on viewing source code.

Loading