From 6fe5c44c1c5839f56330d47f8fcc999a4ad54277 Mon Sep 17 00:00:00 2001 From: haiwu <54203997+Chinoholo0807@users.noreply.github.com> Date: Tue, 26 Sep 2023 22:48:00 +0800 Subject: [PATCH] feat: support radius accounting request (#2362) * feat: add radius server * feat: parse org from packet * feat: add comment * feat: support radius accounting * feat: change log * feat: add copyright --- conf/app.conf | 1 + object/ormer.go | 5 ++ object/radius.go | 124 ++++++++++++++++++++++++++++++++++++++++++ radius/server.go | 64 ++++++++++++++++++---- radius/server_test.go | 12 +--- radius/util.go | 61 +++++++++++++++------ 6 files changed, 229 insertions(+), 38 deletions(-) create mode 100644 object/radius.go diff --git a/conf/app.conf b/conf/app.conf index b68bc4cb..ac81c7fa 100644 --- a/conf/app.conf +++ b/conf/app.conf @@ -21,6 +21,7 @@ isDemoMode = false batchSize = 100 ldapServerPort = 389 radiusServerPort = 1812 +radiusSecret = "secret" quota = {"organization": -1, "user": -1, "application": -1, "provider": -1} logConfig = {"filename": "logs/casdoor.log", "maxdays":99999, "perm":"0770"} initDataFile = "./init_data.json" \ No newline at end of file diff --git a/object/ormer.go b/object/ormer.go index 87d27a8e..15e75aa8 100644 --- a/object/ormer.go +++ b/object/ormer.go @@ -316,6 +316,11 @@ func (a *Ormer) createTable() { panic(err) } + err = a.Engine.Sync2(new(RadiusAccounting)) + if err != nil { + panic(err) + } + err = a.Engine.Sync2(new(PermissionRule)) if err != nil { panic(err) diff --git a/object/radius.go b/object/radius.go new file mode 100644 index 00000000..b760fbe6 --- /dev/null +++ b/object/radius.go @@ -0,0 +1,124 @@ +package object + +import ( + "fmt" + "time" + + "github.com/casdoor/casdoor/util" + "github.com/xorm-io/core" +) + +// https://www.cisco.com/c/en/us/td/docs/ios-xml/ios/sec_usr_radatt/configuration/xe-16/sec-usr-radatt-xe-16-book/sec-rad-ov-ietf-attr.html +// https://support.huawei.com/enterprise/zh/doc/EDOC1000178159/35071f9a +type RadiusAccounting struct { + Owner string `xorm:"varchar(100) notnull pk" json:"owner"` + Name string `xorm:"varchar(100) notnull pk" json:"name"` + CreatedTime time.Time `json:"createdTime"` + + Username string `xorm:"index" json:"username"` + ServiceType int64 `json:"serviceType"` // e.g. LoginUser (1) + + NasId string `json:"nasId"` // String identifying the network access server originating the Access-Request. + NasIpAddr string `json:"nasIpAddr"` // e.g. "192.168.0.10" + NasPortId string `json:"nasPortId"` // Contains a text string which identifies the port of the NAS that is authenticating the user. e.g."eth.0" + NasPortType int64 `json:"nasPortType"` // Indicates the type of physical port the network access server is using to authenticate the user. e.g.Ethernet(15) + NasPort int64 `json:"nasPort"` // Indicates the physical port number of the network access server that is authenticating the user. e.g. 233 + + FramedIpAddr string `json:"framedIpAddr"` // Indicates the IP address to be configured for the user by sending the IP address of a user to the RADIUS server. + FramedIpNetmask string `json:"framedIpNetmask"` // Indicates the IP netmask to be configured for the user when the user is using a device on a network. + + AcctSessionId string `xorm:"index" json:"acctSessionId"` + AcctSessionTime int64 `json:"acctSessionTime"` // Indicates how long (in seconds) the user has received service. + AcctInputTotal int64 `json:"acctInputTotal"` + AcctOutputTotal int64 `json:"acctOutputTotal"` + AcctInputPackets int64 `json:"acctInputPackets"` // Indicates how many packets have been received from the port over the course of this service being provided to a framed user. + AcctOutputPackets int64 `json:"acctOutputPackets"` // Indicates how many packets have been sent to the port in the course of delivering this service to a framed user. + AcctTerminateCause int64 `json:"acctTerminateCause"` // e.g. Lost-Carrier (2) + LastUpdate time.Time `json:"lastUpdate"` + AcctStartTime time.Time `xorm:"index" json:"acctStartTime"` + AcctStopTime time.Time `xorm:"index" json:"acctStopTime"` +} + +func (ra *RadiusAccounting) GetId() string { + return util.GetId(ra.Owner, ra.Name) +} + +func getRadiusAccounting(owner, name string) (*RadiusAccounting, error) { + if owner == "" || name == "" { + return nil, nil + } + ra := RadiusAccounting{Owner: owner, Name: name} + existed, err := ormer.Engine.Get(&ra) + if err != nil { + return nil, err + } + if existed { + return &ra, nil + } else { + return nil, nil + } +} + +func getPaginationRadiusAccounting(owner, field, value, sortField, sortOrder string, offset, limit int) ([]*RadiusAccounting, error) { + ras := []*RadiusAccounting{} + session := GetSession(owner, offset, limit, field, value, sortField, sortOrder) + err := session.Find(&ras) + if err != nil { + return ras, err + } + return ras, nil +} + +func GetRadiusAccounting(id string) (*RadiusAccounting, error) { + owner, name := util.GetOwnerAndNameFromId(id) + return getRadiusAccounting(owner, name) +} + +func GetRadiusAccountingBySessionId(sessionId string) (*RadiusAccounting, error) { + ras, err := getPaginationRadiusAccounting("", "acct_session_id", sessionId, "created_time", "desc", 0, 1) + if err != nil { + return nil, err + } + if len(ras) == 0 { + return nil, nil + } + return ras[0], nil +} + +func AddRadiusAccounting(ra *RadiusAccounting) error { + _, err := ormer.Engine.Insert(ra) + return err +} + +func DeleteRadiusAccounting(ra *RadiusAccounting) error { + _, err := ormer.Engine.ID(core.PK{ra.Owner, ra.Name}).Delete(&RadiusAccounting{}) + return err +} + +func UpdateRadiusAccounting(id string, ra *RadiusAccounting) error { + owner, name := util.GetOwnerAndNameFromId(id) + _, err := ormer.Engine.ID(core.PK{owner, name}).Update(ra) + return err +} + +func InterimUpdateRadiusAccounting(oldRa *RadiusAccounting, newRa *RadiusAccounting, stop bool) error { + if oldRa.AcctSessionId != newRa.AcctSessionId { + return fmt.Errorf("AcctSessionId is not equal, newRa = %s, oldRa = %s", newRa.AcctSessionId, oldRa.AcctSessionId) + } + oldRa.AcctInputTotal = newRa.AcctInputTotal + oldRa.AcctOutputTotal = newRa.AcctOutputTotal + oldRa.AcctInputPackets = newRa.AcctInputPackets + oldRa.AcctOutputPackets = newRa.AcctOutputPackets + oldRa.AcctSessionTime = newRa.AcctSessionTime + if stop { + oldRa.AcctStopTime = newRa.AcctStopTime + if oldRa.AcctStopTime.IsZero() { + oldRa.AcctStopTime = time.Now() + } + oldRa.AcctTerminateCause = newRa.AcctTerminateCause + } else { + oldRa.LastUpdate = time.Now() + } + + return UpdateRadiusAccounting(oldRa.GetId(), oldRa) +} diff --git a/radius/server.go b/radius/server.go index ce66726c..5a25dc2c 100644 --- a/radius/server.go +++ b/radius/server.go @@ -15,24 +15,27 @@ package radius import ( + "fmt" "log" "github.com/casdoor/casdoor/conf" "github.com/casdoor/casdoor/object" "layeh.com/radius" "layeh.com/radius/rfc2865" + "layeh.com/radius/rfc2866" ) // https://support.huawei.com/enterprise/zh/doc/EDOC1000178159/35071f9a#tab_3 func StartRadiusServer() { + secret := conf.GetConfigString("radiusSecret") server := radius.PacketServer{ Addr: "0.0.0.0:" + conf.GetConfigString("radiusServerPort"), Handler: radius.HandlerFunc(handlerRadius), - SecretSource: radius.StaticSecretSource([]byte(`secret`)), + SecretSource: radius.StaticSecretSource([]byte(secret)), } log.Printf("Starting Radius server on %s", server.Addr) if err := server.ListenAndServe(); err != nil { - log.Printf("StartRadiusServer() failed, err = %s", err.Error()) + log.Printf("StartRadiusServer() failed, err = %v", err) } } @@ -40,6 +43,8 @@ func handlerRadius(w radius.ResponseWriter, r *radius.Request) { switch r.Code { case radius.CodeAccessRequest: handleAccessRequest(w, r) + case radius.CodeAccountingRequest: + handleAccountingRequest(w, r) default: log.Printf("radius message, code = %d", r.Code) } @@ -48,20 +53,57 @@ func handlerRadius(w radius.ResponseWriter, r *radius.Request) { func handleAccessRequest(w radius.ResponseWriter, r *radius.Request) { username := rfc2865.UserName_GetString(r.Packet) password := rfc2865.UserPassword_GetString(r.Packet) - organization := parseOrganization(r.Packet) - code := radius.CodeAccessAccept - - log.Printf("username=%v, password=%v, code=%v, org=%v", username, password, code, organization) + organization := rfc2865.Class_GetString(r.Packet) + log.Printf("handleAccessRequest() username=%v, org=%v, password=%v", username, organization, password) if organization == "" { - code = radius.CodeAccessReject - w.Write(r.Response(code)) + w.Write(r.Response(radius.CodeAccessReject)) return } _, msg := object.CheckUserPassword(organization, username, password, "en") if msg != "" { - code = radius.CodeAccessReject - w.Write(r.Response(code)) + w.Write(r.Response(radius.CodeAccessReject)) return } - w.Write(r.Response(code)) + w.Write(r.Response(radius.CodeAccessAccept)) +} + +func handleAccountingRequest(w radius.ResponseWriter, r *radius.Request) { + statusType := rfc2866.AcctStatusType_Get(r.Packet) + username := rfc2865.UserName_GetString(r.Packet) + organization := rfc2865.Class_GetString(r.Packet) + log.Printf("handleAccountingRequest() username=%v, org=%v, statusType=%v", username, organization, statusType) + w.Write(r.Response(radius.CodeAccountingResponse)) + var err error + defer func() { + if err != nil { + log.Printf("handleAccountingRequest() failed, err = %v", err) + } + }() + switch statusType { + case rfc2866.AcctStatusType_Value_Start: + // Start an accounting session + ra := GetAccountingFromRequest(r) + err = object.AddRadiusAccounting(ra) + case rfc2866.AcctStatusType_Value_InterimUpdate, rfc2866.AcctStatusType_Value_Stop: + // Interim update to an accounting session | Stop an accounting session + var ( + newRa = GetAccountingFromRequest(r) + oldRa *object.RadiusAccounting + ) + oldRa, err = object.GetRadiusAccountingBySessionId(newRa.AcctSessionId) + if err != nil { + return + } + if oldRa == nil { + if err = object.AddRadiusAccounting(newRa); err != nil { + return + } + } + stop := statusType == rfc2866.AcctStatusType_Value_Stop + err = object.InterimUpdateRadiusAccounting(oldRa, newRa, stop) + case rfc2866.AcctStatusType_Value_AccountingOn, rfc2866.AcctStatusType_Value_AccountingOff: + // By default, no Accounting-On or Accounting-Off messages are sent (no acct-on-off). + default: + err = fmt.Errorf("unsupport statusType = %v", statusType) + } } diff --git a/radius/server_test.go b/radius/server_test.go index d96c7380..0955edd3 100644 --- a/radius/server_test.go +++ b/radius/server_test.go @@ -29,11 +29,7 @@ func TestAccessRequestRejected(t *testing.T) { packet := radius.New(radius.CodeAccessRequest, []byte(`secret`)) rfc2865.UserName_SetString(packet, "admin") rfc2865.UserPassword_SetString(packet, "12345") - vsa, err := radius.NewVendorSpecific(OrganizationVendorID, []byte("built-in")) - if err != nil { - t.Fatal(err) - } - packet.Add(rfc2865.VendorSpecific_Type, vsa) + rfc2865.Class_SetString(packet, "built-in") response, err := radius.Exchange(context.Background(), packet, "localhost:1812") if err != nil { t.Fatal(err) @@ -47,11 +43,7 @@ func TestAccessRequestAccepted(t *testing.T) { packet := radius.New(radius.CodeAccessRequest, []byte(`secret`)) rfc2865.UserName_SetString(packet, "admin") rfc2865.UserPassword_SetString(packet, "123") - vsa, err := radius.NewVendorSpecific(OrganizationVendorID, []byte("built-in")) - if err != nil { - t.Fatal(err) - } - packet.Add(rfc2865.VendorSpecific_Type, vsa) + rfc2865.Class_SetString(packet, "built-in") response, err := radius.Exchange(context.Background(), packet, "localhost:1812") if err != nil { t.Fatal(err) diff --git a/radius/util.go b/radius/util.go index 203d3a08..56791e56 100644 --- a/radius/util.go +++ b/radius/util.go @@ -15,26 +15,53 @@ package radius import ( + "fmt" + "time" + + "github.com/casdoor/casdoor/object" + "github.com/casdoor/casdoor/util" "layeh.com/radius" "layeh.com/radius/rfc2865" + "layeh.com/radius/rfc2866" + "layeh.com/radius/rfc2869" ) -const ( - OrganizationVendorID = uint32(100) -) - -func parseOrganization(p *radius.Packet) string { - for _, avp := range p.Attributes { - if avp.Type == rfc2865.VendorSpecific_Type { - attr := avp.Attribute - vendorId, value, err := radius.VendorSpecific(attr) - if err != nil { - return "" - } - if vendorId == OrganizationVendorID { - return string(value) - } - } +func GetAccountingFromRequest(r *radius.Request) *object.RadiusAccounting { + acctInputOctets := int(rfc2866.AcctInputOctets_Get(r.Packet)) + acctInputGigawords := int(rfc2869.AcctInputGigawords_Get(r.Packet)) + acctOutputOctets := int(rfc2866.AcctOutputOctets_Get(r.Packet)) + acctOutputGigawords := int(rfc2869.AcctOutputGigawords_Get(r.Packet)) + organization := rfc2865.Class_GetString(r.Packet) + getAcctStartTime := func(sessionTime int) time.Time { + m, _ := time.ParseDuration(fmt.Sprintf("-%ds", sessionTime)) + return time.Now().Add(m) } - return "" + ra := &object.RadiusAccounting{ + Owner: organization, + Name: "ra_" + util.GenerateId()[:6], + CreatedTime: time.Now(), + + Username: rfc2865.UserName_GetString(r.Packet), + ServiceType: int64(rfc2865.ServiceType_Get(r.Packet)), + + NasId: rfc2865.NASIdentifier_GetString(r.Packet), + NasIpAddr: rfc2865.NASIPAddress_Get(r.Packet).String(), + NasPortId: rfc2869.NASPortID_GetString(r.Packet), + NasPortType: int64(rfc2865.NASPortType_Get(r.Packet)), + NasPort: int64(rfc2865.NASPort_Get(r.Packet)), + + FramedIpAddr: rfc2865.FramedIPAddress_Get(r.Packet).String(), + FramedIpNetmask: rfc2865.FramedIPNetmask_Get(r.Packet).String(), + + AcctSessionId: rfc2866.AcctSessionID_GetString(r.Packet), + AcctSessionTime: int64(rfc2866.AcctSessionTime_Get(r.Packet)), + AcctInputTotal: int64(acctInputOctets) + int64(acctInputGigawords)*4*1024*1024*1024, + AcctOutputTotal: int64(acctOutputOctets) + int64(acctOutputGigawords)*4*1024*1024*1024, + AcctInputPackets: int64(rfc2866.AcctInputPackets_Get(r.Packet)), + AcctOutputPackets: int64(rfc2866.AcctInputPackets_Get(r.Packet)), + AcctStartTime: getAcctStartTime(int(rfc2866.AcctSessionTime_Get(r.Packet))), + AcctTerminateCause: int64(rfc2866.AcctTerminateCause_Get(r.Packet)), + LastUpdate: time.Now(), + } + return ra }