visit
// cmd/server/main.go
package main
import (
"github.com/port-scanner/cmd/server/cfg"
"github.com/port-scanner/cmd/server/router"
"github.com/port-scanner/pkg/gracefulshutdown"
"github.com/port-scanner/pkg/server"
)
// main loads configuration, builds the HTTP server, runs it in the
// background and hands control to the graceful-shutdown handler.
func main() {
	conf := cfg.New()

	// Assemble the HTTP server: listen address, routes and error logger all
	// come from the loaded configuration.
	srv := server.New().
		WithAddr(conf.GetAPIPort()).
		WithRouter(router.Get(conf)).
		WithErrLogger(conf.Errlog)

	// startServer runs the server. It is launched on its own goroutine so
	// that it does not block installation of the shutdown handler below.
	startServer := func() {
		conf.Infolog.Printf("starting server at %s", conf.GetAPIPort())
		if err := srv.Start(); err != nil {
			conf.Errlog.Printf("starting server: %s", err)
		}
	}
	go startServer()

	// stopServer is the cleanup callback given to the graceful-shutdown
	// handler, which listens for termination signals.
	stopServer := func() {
		if err := srv.Close(); err != nil {
			conf.Errlog.Printf("closing server: %s", err)
		}
	}
	gracefulshutdown.Init(conf.Errlog, stopServer)
}
I will only have one endpoint: `GET /open-ports`. Let's review the handler:
// cmd/server/handlers/scanports.go
package handlers
import (
"encoding/json"
"net/http"
"strconv"
"github.com/julienschmidt/httprouter"
"github.com/port-scanner/pkg/portscanner"
"github.com/port-scanner/pkg/reqvalidator"
)
// openPorts is the JSON report returned by ScanOpenPorts.
type openPorts struct {
	FromPort  int    `json:"from_port"`
	ToPort    int    `json:"to_port"`
	Domain    string `json:"domain"`
	OpenPorts []int  `json:"open_ports"`
}

// ScanOpenPorts scans a host for open TCP ports and reports them as JSON.
//
// Query parameters (both required):
//   - domain: the host to scan.
//   - toPort: the highest port to probe, a decimal integer in [1, 65535].
//
// Invalid input gets 400 Bad Request with the validator's JSON error body.
// Otherwise the response is 200 OK with an openPorts report. If the report
// cannot be encoded, the response is 500 with no body.
func ScanOpenPorts(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
	// Highest valid TCP port. The scanner's workload grows with toPort, so
	// this also caps how much work a single request can trigger.
	const maxPort = 65535

	w.Header().Set("Content-Type", "application/json")

	query := r.URL.Query()
	domain := query.Get("domain")
	toPort := query.Get("toPort")

	v := reqvalidator.New()
	v.Required("domain", domain)
	v.Required("toPort", toPort)
	v.ValidDecimalString("toPort", toPort)
	// Reject out-of-range ports. Without this check, a huge or negative value
	// makes the scanner spawn unbounded work.
	if n, err := strconv.Atoi(toPort); err == nil && (n < 1 || n > maxPort) {
		v.Errors.Add("toPort", "must be between 1 and 65535")
	}
	if !v.Valid() {
		// The problem is the client's input, not authorization, so 400 rather than 403.
		w.WriteHeader(http.StatusBadRequest)
		w.Write(v.GetErrResp())
		return
	}

	// Safe to ignore the error: the validator above already confirmed toPort parses.
	port, _ := strconv.Atoi(toPort)
	open := portscanner.New(domain).ScanTo(port)
	if open == nil {
		// Encode "no open ports" as [] rather than null.
		open = []int{}
	}

	resp, err := json.Marshal(openPorts{
		FromPort:  1, // the scanner probes ports 1..toPort; port 0 is never scanned
		ToPort:    port,
		Domain:    domain,
		OpenPorts: open,
	})
	if err != nil {
		w.WriteHeader(http.StatusInternalServerError)
		return
	}
	w.WriteHeader(http.StatusOK)
	w.Write(resp)
}
I expect to receive two query params in this handler: `domain` and `toPort`. Both params are mandatory, and `toPort` must be a valid decimal string representation of an integer. The validator checks exactly that. The `toPort` number is the highest port my scanner will check, and it also defines the size of the worker pool. This allows the scanner to perform all calls at once.
Before going any further, let's review the validator:
// pkg/reqvalidator/reqvalidator.go
package reqvalidator
import (
"encoding/json"
"strconv"
"strings"
)
// validationerrors maps a field name to every validation message recorded
// for that field, in the order the messages were added.
type validationerrors map[string][]string

// Add records message against field, keeping any messages already stored
// for it.
func (errs validationerrors) Add(field, message string) {
	existing := errs[field]
	errs[field] = append(existing, message)
}
// errResp is the JSON body sent back for a request that failed validation.
// Each tag repeats its Go field name, so encoding/json emits exactly the
// keys it would use by default.
type errResp struct {
	Result        string           `json:"Result"`
	Cause         string           `json:"Cause"`
	InvalidFields validationerrors `json:"InvalidFields"`
}
// ReqValidator accumulates per-field validation failures for a request.
// Its methods take value receivers but still record errors, because Errors
// is a map and copies of the struct share it.
type ReqValidator struct {
	Errors validationerrors
}

// New returns a ReqValidator with an empty, ready-to-use error set.
func New() ReqValidator {
	return ReqValidator{Errors: make(validationerrors)}
}

// Required records "can't be blank" for field when value is empty or
// consists only of whitespace.
func (v ReqValidator) Required(field, value string) {
	if strings.TrimSpace(value) != "" {
		return
	}
	v.Errors.Add(field, "can't be blank")
}

// ValidDecimalString records "invalid decimal string" for field when value
// cannot be parsed by strconv.Atoi.
func (v ReqValidator) ValidDecimalString(field, value string) {
	if _, err := strconv.Atoi(value); err == nil {
		return
	}
	v.Errors.Add(field, "invalid decimal string")
}

// Valid reports whether no validation errors have been recorded.
func (v ReqValidator) Valid() bool {
	hasErrors := len(v.Errors) > 0
	return !hasErrors
}

// GetErrResp returns the JSON error body listing every recorded failure.
func (v ReqValidator) GetErrResp() []byte {
	// The payload holds only strings and a map of string slices, which
	// encoding/json always accepts, so the error is discarded.
	body, _ := json.Marshal(errResp{
		Result:        "ERROR",
		Cause:         "INVALID_REQUEST",
		InvalidFields: v.Errors,
	})
	return body
}
`Required` and `ValidDecimalString` are my basic validations. Given a field name and value, each one records an error message if validation fails. Calling `Valid` checks whether any errors were recorded. If there are errors, I can retrieve them with `GetErrResp` and send that information to the consumer.

Now let's review the scanner:
// pkg/portscanner/portscanner.go
package portscanner
import (
	"fmt"
	"net"
	"sort"
	"time"
)
// maxConcurrentDials caps how many ports are probed at once. Opening one
// socket per port (up to 65535 simultaneously) can exhaust the process's
// file-descriptor limit. Dials then fail and open ports get misreported as
// closed.
const maxConcurrentDials = 1000

// dialTimeout bounds how long a single port probe waits for a TCP handshake.
const dialTimeout = 2 * time.Second

// Scanner probes TCP ports on a single host.
type Scanner struct {
	domain string
}

// New returns a Scanner that probes the given domain (a host name or IP
// address, including IPv6 literals such as "::1").
func New(domain string) Scanner {
	return Scanner{domain}
}

// ScanTo probes TCP ports 1 through toPort (inclusive) on the scanner's
// domain and returns the ports that accepted a connection, in ascending
// order. A non-positive toPort, or a scan that finds nothing, yields an
// empty, non-nil slice. At most maxConcurrentDials probes run at a time.
func (s Scanner) ScanTo(toPort int) (openPorts []int) {
	openPorts = []int{}
	if toPort < 1 {
		return openPorts
	}

	workers := toPort
	if workers > maxConcurrentDials {
		workers = maxConcurrentDials
	}

	jobs := make(chan int)
	results := make(chan int)
	for i := 0; i < workers; i++ {
		go worker(s.domain, jobs, results)
	}

	// The producer owns jobs, so it closes the channel once every port is
	// queued. That lets the workers' range loops end.
	go func() {
		defer close(jobs)
		for port := 1; port <= toPort; port++ {
			jobs <- port
		}
	}()

	// Each queued port yields exactly one result: the port if open, or 0.
	for i := 0; i < toPort; i++ {
		if port := <-results; port != 0 {
			openPorts = append(openPorts, port)
		}
	}

	// Results arrive in completion order; sort them for a stable report.
	sort.Ints(openPorts)
	return openPorts
}

// worker dials domain on every port received from jobs. For each port it
// sends the port number to results if the TCP connection succeeded, or 0
// otherwise.
func worker(domain string, jobs <-chan int, results chan<- int) {
	for port := range jobs {
		// JoinHostPort brackets IPv6 literals ("[::1]:80"), which a plain
		// "%s:%d" format would not.
		conn, err := net.DialTimeout("tcp", net.JoinHostPort(domain, fmt.Sprint(port)), dialTimeout)
		if err != nil {
			results <- 0
			continue
		}
		// Release the socket right away; only reachability matters.
		conn.Close()
		results <- port
	}
}
The `ScanTo` function does the heavy lifting here. I first create the jobs and results channels, then spawn the workers. Each worker receives a port over the jobs channel, scans that port, and sends the result of the TCP call over the results channel. I use `net.DialTimeout` so that a worker does not wait for a response for longer than 2 seconds. When the last port has finished scanning, I return all open ports to the consumer.

You can already try interacting with this tool by running:

$ go run cmd/server/main.go

and issuing curl commands against `localhost:8080/open-ports`, for example:

$ curl "localhost:8080/open-ports?domain=localhost&toPort=100"

I hope you have learned something useful. In the next post I'll go through the steps needed to containerise this app, deploy it, and scale it using k8s. You can find the full source code in the project repository.